Polished Grouped GEMM APIs and new BF16 instances (#1600)

* Few small fixes.

* New GroupedGemm instances (BF16)

* Unify and refactor GroupedGEMM device API.

* Adapt changes to new API.

* Adapt grouped gemm profiler.

* Accept multiple kbatches for grouped gemm profiler.

- delete obsolete two stage as it is now covered by grouped gemm

* Update unit test for grouped gemm.

* Fix thresholds for BF16 and F8. Unblock tests.

* Fix few instances.

* Multiple small fixes.

* Adapt to new API, check dynamic casting.

* Uncomment few data types in grouped gemm profiler.

* Fix call to SetDeviceArgs.

* Fix profile grouped gemm multiply tile loop.

* Fix grouped gemm tile loop kernel args in client examples.

* Review comments.
This commit is contained in:
Adam Osewski
2024-11-27 13:02:44 +01:00
committed by GitHub
parent cb8c7f42d6
commit 061ac0649c
51 changed files with 1400 additions and 1723 deletions

View File

@@ -121,7 +121,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
constexpr ck::index_t NumDTensor = 2;
using GroupedGemmKernelArgument =
ck::tensor_operation::device::GroupedGemmTileLoopKernelArguments<NumDTensor>;
ck::tensor_operation::device::GroupedGemmKernelArgument<NumDTensor>;
std::vector<GroupedGemmKernelArgument> grouped_gemm_kernel_args_;
grouped_gemm_kernel_args_.reserve(group_count);

View File

@@ -120,7 +120,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
constexpr ck::index_t NumDTensor = 1;
using GroupedGemmKernelArgument =
ck::tensor_operation::device::GroupedGemmTileLoopKernelArguments<NumDTensor>;
ck::tensor_operation::device::GroupedGemmKernelArgument<NumDTensor>;
std::vector<GroupedGemmKernelArgument> grouped_gemm_kernel_args_;
grouped_gemm_kernel_args_.reserve(group_count);

View File

@@ -246,7 +246,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
// do GEMM
auto argument = gemm.MakeArgument(
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op);
gemm.SetKBatchSize(argument, config.k_batch);
gemm.SetKBatchSize(&argument, config.k_batch);
if(!gemm.IsSupportedArgument(argument))
{
throw std::runtime_error(
@@ -257,7 +257,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer());
DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument));
gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer());
gemm.SetDeviceKernelArgs(&argument, gemm_arg_dev_mem.GetDeviceBuffer());
invoker.Run(argument, StreamConfig{nullptr, false, 1});

View File

@@ -91,7 +91,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
{
auto group_count = problem_size.group_count;
using KernelArguments = ck::tensor_operation::device::GroupedGemmTileLoopKernelArguments<NumDs>;
using KernelArguments = ck::tensor_operation::device::GroupedGemmKernelArgument<NumDs>;
using GemmDesc = ck::tensor_operation::device::GemmDesc;
// GEMM shape

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
@@ -254,7 +254,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
gemm.GetDeviceKernelArgSize(&argument),
hipMemcpyHostToDevice));
gemm.SetDeviceKernelArgs(argument, gemm_kernel_args_dev.GetDeviceBuffer());
gemm.SetDeviceKernelArgs(&argument, gemm_kernel_args_dev.GetDeviceBuffer());
gemm.SetKBatch(argument, config.k_batch);
invoker.Run(argument, StreamConfig{nullptr, false});

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
@@ -239,7 +239,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
"not support this GEMM problem");
}
gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer());
gemm.SetDeviceKernelArgs(&argument, gemm_arg_dev_mem.GetDeviceBuffer());
gemm.SetKBatch(argument, config.k_batch);
invoker.Run(argument, StreamConfig{nullptr, false});

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
@@ -240,7 +240,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
"not support this GEMM problem");
}
gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer());
gemm.SetDeviceKernelArgs(&argument, gemm_arg_dev_mem.GetDeviceBuffer());
gemm.SetKBatch(argument, config.k_batch);
invoker.Run(argument, StreamConfig{nullptr, false});

View File

@@ -168,9 +168,23 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
auto argument = gemm.MakeArgument(
p_a, p_b, p_Ds, p_c, gemm_descs, a_element_op, b_element_op, c_element_op);
DeviceMem gemm_desc_workspace(gemm.GetWorkSpaceSize(&argument));
std::size_t workspace_size = gemm.GetWorkSpaceSize(&argument);
std::size_t kargs_size = gemm.GetDeviceKernelArgSize(&argument);
gemm.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer());
DeviceMem gemm_workspace, gemm_kargs;
// The following is necessary since TwoStage kernel is using additional memory both
// for Workspace and kernel arguments.
if(kargs_size > 0)
{
gemm_kargs.Realloc(kargs_size);
gemm.SetDeviceKernelArgs(&argument, gemm_kargs.GetDeviceBuffer());
}
if(workspace_size > 0 && workspace_size != kargs_size)
{
gemm_workspace.Realloc(workspace_size);
gemm.SetWorkSpacePointer(&argument, gemm_workspace.GetDeviceBuffer());
}
if(!gemm.IsSupportedArgument(argument))
{

View File

@@ -1,17 +1,87 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <vector>
#include "device_base.hpp"
#include "ck/utility/ignore.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
///
/// @brief Structure representing single GEMM problem arguments.
///
/// The pointer to the vector of those structures is passed to the GroupedGEMM entry
/// point kernel.
///
/// @tparam NumDTensor The number of D input tensors.
///
template <index_t NumDTensor = 0>
struct GroupedGemmKernelArgument
{
__host__ __device__ GroupedGemmKernelArgument(const void* p_a_grid_,
const void* p_b_grid_,
std::array<const void*, NumDTensor> p_ds_grid_,
void* p_e_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideE_)
: p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_ds_grid{p_ds_grid_},
p_e_grid{p_e_grid_},
M{M_},
N{N_},
K{K_},
StrideA{StrideA_},
StrideB{StrideB_},
StrideDs{StrideDs_},
StrideE{StrideE_}
{
}
const void* p_a_grid;
const void* p_b_grid;
std::array<const void*, NumDTensor> p_ds_grid;
void* p_e_grid;
index_t M;
index_t N;
index_t K;
index_t StrideA;
index_t StrideB;
std::array<index_t, NumDTensor> StrideDs;
index_t StrideE;
void Print() const
{
std::stringstream str;
for(auto sd : StrideDs)
str << sd << ",";
std::cout << "arg {"
<< "M:" << M << ", "
<< "N:" << N << ", "
<< "K:" << K << ", "
<< "SA:" << StrideA << ", "
<< "SB:" << StrideB << ", "
<< "SE:" << StrideE << ", "
<< "SDs: {" << str.str() << "}"
<< "}" << std::endl;
}
};
struct GemmDesc
{
ck::index_t M_, N_, K_;
@@ -48,6 +118,66 @@ struct DeviceGroupedGemm : public BaseOperator
CElementwiseOperation c_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
//---------------------------------------------------------------------------------------------
/// @brief Sets the device kernel arguments pointer and may copy data to device.
///
/// TODO: Add which kernels are using this (TileLoop * FixedNK ??)
///
/// @param p_arg The pointer to the Argument we're going to update.
/// @param[in] p_dev_kernel_args The pointer to the device memory which will contain kernel
/// arguments.
/// @param[in] p_host_kernel_args The pointer to the host memory which contains kernel
/// arguments that should be copied to device memory.
///
virtual void SetDeviceKernelArgs(BaseArgument* p_arg,
void* p_dev_kernel_args,
const void* p_host_kernel_args) const
{
ignore = p_arg;
ignore = p_dev_kernel_args;
ignore = p_host_kernel_args;
std::ostringstream err;
err << "This function is not implemented by the kernel: " << this->GetTypeString()
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
//----------------------------------------------------------------------------------------------
/// @brief Sets the device kernel arguments pointer and may copy data to device.
///
/// @param p_arg The pointer to the Argument we're going to update.
/// @param[in] p_dev_kernel_args The pointer to the device memory which contains kernel
/// arguments.
///
virtual void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const
{
ignore = p_arg;
ignore = p_dev_kernel_args;
std::ostringstream err;
err << "This function is not implemented by the kernel: " << this->GetTypeString()
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
//----------------------------------------------------------------------------------------------
/// @brief Gets the device kernel argument size.
///
/// @param[in] p_arg The pointer to the Device op Argument.
///
/// @return The device kernel argument size.
///
virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const
{
ignore = p_arg;
std::ostringstream err;
err << "This function is not implemented by the kernel: " << this->GetTypeString()
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
};
} // namespace device

View File

@@ -1,35 +1,14 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <array>
#include "device_grouped_gemm.hpp"
#include "device_grouped_gemm_splitk.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <index_t NumDTensor = 0>
struct GroupedGemmKernelArgument
{
const void* p_a_grid;
const void* p_b_grid;
std::array<const void*, NumDTensor> p_ds_grid;
void* p_e_grid;
index_t M;
index_t N;
index_t K;
index_t StrideA;
index_t StrideB;
std::array<index_t, NumDTensor> StrideDs;
index_t StrideE;
};
template <typename ALayout,
typename BLayout,
typename DsLayout,
@@ -41,21 +20,18 @@ template <typename ALayout,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGroupedGemmFixedNK : DeviceGroupedGemm<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
struct DeviceGroupedGemmFixedNK : DeviceGroupedGemmSplitK<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
virtual void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const = 0;
virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const = 0;
virtual void SetKBatch(BaseArgument* p_arg, index_t k_batch) const = 0;
};
} // namespace device

View File

@@ -1,136 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include <iostream>
#include <vector>
#include <sstream>
#include "device_grouped_gemm.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
///
/// @brief Structure representing single GEMM problem arguments.
///
/// The pointer to the vector of those structures is passed to the GroupedGEMM entry
/// point kernel.
///
/// @tparam NumDTensor The number of D input tensors.
///
template <index_t NumDTensor = 0>
struct GroupedGemmMultipleDKernelArguments
{
__host__ __device__
GroupedGemmMultipleDKernelArguments(const void* p_a_grid_,
const void* p_b_grid_,
std::array<const void*, NumDTensor> p_ds_grid_,
void* p_e_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideE_)
: p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_ds_grid{p_ds_grid_},
p_e_grid{p_e_grid_},
M{M_},
N{N_},
K{K_},
StrideA{StrideA_},
StrideB{StrideB_},
StrideDs{StrideDs_},
StrideE{StrideE_}
{
}
const void* p_a_grid;
const void* p_b_grid;
std::array<const void*, NumDTensor> p_ds_grid;
void* p_e_grid;
index_t M;
index_t N;
index_t K;
index_t StrideA;
index_t StrideB;
std::array<index_t, NumDTensor> StrideDs;
index_t StrideE;
void Print() const
{
std::stringstream str;
for(auto sd : StrideDs)
str << sd << ",";
std::cout << "arg {"
<< "M:" << M << ", "
<< "N:" << N << ", "
<< "K:" << K << ", "
<< "SA:" << StrideA << ", "
<< "SB:" << StrideB << ", "
<< "SE:" << StrideE << ", "
<< "SDs: {" << str.str() << "}"
<< "}" << std::endl;
}
};
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceGroupedGemmMultipleDSplitK : public DeviceGroupedGemm<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
//----------------------------------------------------------------------------------------------
/// @brief Sets the k batch size.
///
/// @param p_arg Pointer to the Argument we're going to change.
/// @param[in] kbatch The kbatch value.
///
virtual void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const = 0;
//----------------------------------------------------------------------------------------------
/// @brief Sets the device kernel arguments pointer.
///
/// @param p_arg The pointer to the Argument we're going to update.
/// @param[in] p_dev_kernel_args The pointer to the device memory which contains kernel
/// arguments.
///
virtual void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const = 0;
//----------------------------------------------------------------------------------------------
/// @brief Gets the device kernel argument size.
///
/// @param[in] p_arg The pointer to the Device op Argument.
///
/// @return The device kernel argument size.
///
virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,6 +1,6 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "device_grouped_gemm.hpp"
@@ -31,7 +31,23 @@ struct DeviceGroupedGemmSplitK : public DeviceGroupedGemm<ALayout,
BElementwiseOperation,
CElementwiseOperation>
{
//----------------------------------------------------------------------------------------------
/// @brief Sets the k batch size.
///
/// @param p_arg Pointer to the Argument we're going to change.
/// @param[in] kbatch The kbatch value.
///
virtual void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const = 0;
//----------------------------------------------------------------------------------------------
/// @brief Sets the k batch size.
///
/// @param p_arg Pointer to the Argument we're going to change.
/// @param[in] kbatch The kbatch value.
///
virtual void SetKBatch(BaseArgument* p_arg, index_t kbatch) const
{
this->SetKBatchSize(p_arg, kbatch);
};
};
} // namespace device

View File

@@ -3,83 +3,20 @@
#pragma once
#include <array>
#include <iostream>
#include <vector>
#include <sstream>
#include "device_grouped_gemm.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
/// @brief Grouped GEMM kernel using output Tile Looping algorithm
///
/// @brief Structure representing single GEMM problem arguments.
/// @par This kernel does not require any knowledge about input data sizes (GEMM M/N/K)
/// It requires only the number of groups to launch. Other information like
/// data pointers and GEMM sizes, packed into gemm kernel args may be all dynamic
/// (known only at kernel run-time).
///
/// The pointer to the vector of those structures is passed to the GroupedGEMM entry
/// point kernel.
///
/// @tparam NumDTensor The number of D input tensors.
///
template <index_t NumDTensor = 0>
struct GroupedGemmTileLoopKernelArguments
{
__host__ __device__
GroupedGemmTileLoopKernelArguments(const void* p_a_grid_,
const void* p_b_grid_,
std::array<const void*, NumDTensor> p_ds_grid_,
void* p_e_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideE_)
: p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_ds_grid{p_ds_grid_},
p_e_grid{p_e_grid_},
M{M_},
N{N_},
K{K_},
StrideA{StrideA_},
StrideB{StrideB_},
StrideDs{StrideDs_},
StrideE{StrideE_}
{
}
const void* p_a_grid;
const void* p_b_grid;
std::array<const void*, NumDTensor> p_ds_grid;
void* p_e_grid;
index_t M;
index_t N;
index_t K;
index_t StrideA;
index_t StrideB;
std::array<index_t, NumDTensor> StrideDs;
index_t StrideE;
void Print() const
{
std::stringstream str;
for(auto sd : StrideDs)
str << sd << ",";
std::cout << "arg {"
<< "M:" << M << ", "
<< "N:" << N << ", "
<< "K:" << K << ", "
<< "SA:" << StrideA << ", "
<< "SB:" << StrideB << ", "
<< "SE:" << StrideE << ", "
<< "SDs: {" << str.str() << "}"
<< "}" << std::endl;
}
};
/// @note This kernel does not support SplitK.
template <typename ALayout,
typename BLayout,
@@ -104,23 +41,6 @@ struct DeviceGroupedGemmTileLoop : public DeviceGroupedGemm<ALayout,
BElementwiseOperation,
CDEElementwiseOperation>
{
//----------------------------------------------------------------------------------------------
/// @brief Sets the device kernel arguments pointer.
///
/// @param p_arg The pointer to the Argument we're going to update.
/// @param[in] p_dev_kernel_args The pointer to the device memory which contains kernel
/// arguments.
///
virtual void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const = 0;
//----------------------------------------------------------------------------------------------
/// @brief Gets the device kernel argument size.
///
/// @param[in] p_arg The pointer to the Device op Argument.
///
/// @return The device kernel argument size.
///
virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const = 0;
};
} // namespace device

View File

@@ -18,7 +18,6 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
@@ -78,17 +77,17 @@ template <typename ALayout,
// TODO: change gridwise_gemm_v2r4r2 to support AK1 & BK1
enable_if_t<AK1 == BK1, bool> = false>
struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
: public DeviceGroupedGemmMultipleDSplitK<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
: public DeviceGroupedGemmSplitK<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
using DeviceOp = DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage;
@@ -530,7 +529,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
index_t skipped_group_count_;
index_t grid_size_;
// Pointer to device memory with GEMM kernel arguments.
const void* p_dev_gemm_args_;
void* p_dev_gemm_kargs_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
@@ -566,7 +565,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
/// @return The average kernel execution time (if time measurement is enabled.)
///
float Run(const Argument& arg,
const void* dev_gemm_args,
void* dev_gemm_args,
void* dev_gemm_workspace,
const StreamConfig& stream_config = StreamConfig{})
{
@@ -621,7 +620,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
///
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(arg.p_dev_gemm_args_ == nullptr)
if(arg.p_dev_gemm_kargs_ == nullptr)
{
std::ostringstream err;
err << "The gemm arguments device buffer is not allocated!"
@@ -637,7 +636,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
throw std::runtime_error(err.str());
}
return Run(arg, arg.p_dev_gemm_args_, arg.p_workspace_, stream_config);
return Run(arg, arg.p_dev_gemm_kargs_, arg.p_workspace_, stream_config);
}
float Run(const BaseArgument* p_arg,
@@ -723,7 +722,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
template <bool HasMainKBlockLoop>
float DispatchKernel(const Argument& arg,
const void* dev_gemm_args,
void* dev_gemm_kargs,
void* dev_gemm_workspace,
const StreamConfig& stream_config) const
{
@@ -746,7 +745,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
return LaunchKernel(gemm_kernel,
elementwise_kernel,
arg,
dev_gemm_args,
dev_gemm_kargs,
dev_gemm_workspace,
stream_config);
}
@@ -755,12 +754,19 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
float LaunchKernel(const KernelFunction& gemm_kernel,
const KernelFunction2& elementwise_kernel,
const Argument& arg,
const void* dev_gemm_args,
void* dev_gemm_kargs,
[[maybe_unused]] void* dev_gemm_workspace,
const StreamConfig& stream_config) const
{
float time{0.f};
hip_check_error(
hipMemcpyWithStream(dev_gemm_kargs,
arg.gemm_kernel_args_.data(),
arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
hipMemcpyHostToDevice,
stream_config.stream_id_));
auto preprocess = [&]() {
hip_check_error(hipMemsetAsync(
dev_gemm_workspace, 0, arg.GetWorkspaceSizeBytes(), stream_config.stream_id_));
@@ -774,7 +780,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
dim3(arg.grid_size_),
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(dev_gemm_args),
cast_pointer_to_constant_address_space(dev_gemm_kargs),
arg.gemm_kernel_args_.size(),
arg.a_element_op_,
arg.b_element_op_,
@@ -930,18 +936,30 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
return str.str();
}
void SetDeviceKernelArgs(Argument& arg, void* p_dev_kernel_args) const
{
arg.p_dev_gemm_args_ = p_dev_kernel_args;
hip_check_error(hipMemcpy(p_dev_kernel_args,
arg.gemm_kernel_args_.data(),
GetDeviceKernelArgSize(&arg),
hipMemcpyHostToDevice));
}
void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
{
return SetDeviceKernelArgs(*dynamic_cast<Argument*>(p_arg), p_dev_kernel_args);
auto arg_ptr = dynamic_cast<Argument*>(p_arg);
if(arg_ptr)
{
arg_ptr->p_dev_gemm_kargs_ = p_dev_kernel_args;
}
else
throw std::runtime_error(
"The argument pointer is not an object of "
"DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!");
}
size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
{
auto arg = dynamic_cast<const Argument*>(p_arg);
if(arg)
{
return arg->gemm_kernel_args_.size() * sizeof(GemmTransKernelArg);
}
else
throw std::runtime_error(
"The argument pointer is not an object of "
"DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!");
}
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
@@ -974,17 +992,22 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
"DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!");
}
static void SetKBatchSize(Argument& arg, index_t kbatch) { arg.UpdateKBatch(kbatch); }
[[deprecated]] static void SetKBatchSize(Argument& arg, index_t kbatch)
{
arg.UpdateKBatch(kbatch);
}
void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override
{
return SetKBatchSize(*dynamic_cast<Argument*>(p_arg), kbatch);
}
size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
{
return dynamic_cast<const Argument*>(p_arg)->gemm_kernel_args_.size() *
sizeof(GemmTransKernelArg);
auto p_arg_ = dynamic_cast<Argument*>(p_arg);
if(p_arg_)
{
p_arg_->UpdateKBatch(kbatch);
}
else
throw std::runtime_error(
"The argument pointer is not an object of "
"DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!");
}
};

View File

@@ -20,7 +20,6 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp" // stare wywalic
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
namespace ck {
@@ -522,7 +521,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
ComputeTypeA,
ComputeTypeB>;
using KernelArguments = GroupedGemmTileLoopKernelArguments<NumDTensor>;
using KernelArguments = GroupedGemmKernelArgument<NumDTensor>;
using Block2ETileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
using OffsettedLocalBlock2ETileMap = OffsettedBlockToCTileMap2<Block2ETileMap>;
@@ -936,12 +935,31 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
return str.str();
}
void SetDeviceKernelArgs(Argument& arg,
void* p_dev_kernel_args,
const void* p_host_kernel_args) const
{
arg.p_dev_gemm_args_ = p_dev_kernel_args;
hip_check_error(hipMemcpy(p_dev_kernel_args,
p_host_kernel_args,
GetDeviceKernelArgSize(&arg),
hipMemcpyHostToDevice));
}
virtual void SetDeviceKernelArgs(BaseArgument* p_arg,
void* p_dev_kernel_args,
const void* p_host_kernel_args) const override
{
return SetDeviceKernelArgs(
*dynamic_cast<Argument*>(p_arg), p_dev_kernel_args, p_host_kernel_args);
}
void SetDeviceKernelArgs(Argument& arg, void* p_dev_kernel_args) const
{
arg.p_dev_gemm_args_ = p_dev_kernel_args;
}
void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
virtual void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
{
return SetDeviceKernelArgs(*dynamic_cast<Argument*>(p_arg), p_dev_kernel_args);
}

View File

@@ -1,6 +1,6 @@
#pragma once
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -717,7 +717,24 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{
return dynamic_cast<const Argument*>(p_arg)->group_count_ * sizeof(GemmBiasTransKernelArg);
auto p_arg_ = dynamic_cast<const Argument*>(p_arg);
if(p_arg_)
{
return p_arg_->group_count_ * sizeof(GemmBiasTransKernelArg);
}
else
throw std::runtime_error("The argument pointer is not an object of "
"DeviceGroupedGemmMultipleDXdlCShuffle::Argument structure!");
}
size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
{
return GetWorkSpaceSize(p_arg);
}
void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
{
return this->SetWorkSpacePointer(p_arg, p_dev_kernel_args);
}
};

View File

@@ -445,6 +445,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
using Block2ETileMap = BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops<MPerBlock, NPerBlock>;
using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMapMLoops<Block2ETileMap>;
// TODO: replace with GroupedGemmKernelArgument
struct GemmBiasTransKernelArg
{
// pointers
@@ -900,40 +901,58 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
return str.str();
}
static void SetDeviceKernelArgs(Argument& arg, const void* kernel_args)
{
arg.grouped_gemm_kernel_args_dev = kernel_args;
}
// polymorphic
void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const override
void SetDeviceKernelArgs(BaseArgument* p_arg, void* kernel_args) const override
{
return SetDeviceKernelArgs(*dynamic_cast<Argument*>(p_arg), kernel_args);
auto arg_ptr = dynamic_cast<Argument*>(p_arg);
if(arg_ptr)
{
arg_ptr->grouped_gemm_kernel_args_dev = kernel_args;
}
else
throw std::runtime_error("The argument pointer is not an object of "
"DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!");
}
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{
auto arg = *dynamic_cast<const Argument*>(p_arg);
return arg.group_count_ * arg.barrier_size_grp_ * sizeof(uint32_t);
auto arg_ptr = dynamic_cast<const Argument*>(p_arg);
if(arg_ptr)
{
return arg_ptr->group_count_ * arg_ptr->barrier_size_grp_ * sizeof(uint32_t);
}
else
throw std::runtime_error("The argument pointer is not an object of "
"DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!");
}
size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
{
auto arg = *dynamic_cast<const Argument*>(p_arg);
return arg.group_count_ * sizeof(GroupedGemmKernelArgument<NumDTensor>);
auto arg_ptr = dynamic_cast<const Argument*>(p_arg);
if(arg_ptr)
{
return arg_ptr->group_count_ * sizeof(GroupedGemmKernelArgument<NumDTensor>);
}
else
throw std::runtime_error("The argument pointer is not an object of "
"DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!");
}
void SetWorkSpacePointer(BaseArgument* p_arg,
void* p_workspace,
const StreamConfig& stream_config = StreamConfig{}) const override
{
auto p_arg_ = dynamic_cast<Argument*>(p_arg);
p_arg_->p_workspace_ = p_workspace;
auto arg_ptr = dynamic_cast<Argument*>(p_arg);
if(arg_ptr)
{
arg_ptr->p_workspace_ = p_workspace;
}
else
throw std::runtime_error("The argument pointer is not an object of "
"DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!");
hip_check_error(
hipMemsetAsync(p_workspace, 0, GetWorkSpaceSize(p_arg), stream_config.stream_id_));
hipMemsetAsync(p_workspace, 0, GetWorkSpaceSize(arg_ptr), stream_config.stream_id_));
}
static void SetKBatch(Argument& arg, index_t k_batch) { arg.UpdateKBatch(k_batch); }
@@ -941,7 +960,26 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
// polymorphic
void SetKBatch(BaseArgument* p_arg, index_t k_batch) const override
{
return SetKBatch(*dynamic_cast<Argument*>(p_arg), k_batch);
auto arg_ptr = dynamic_cast<Argument*>(p_arg);
if(arg_ptr)
{
arg_ptr->UpdateKBatch(k_batch);
}
else
throw std::runtime_error("The argument pointer is not an object of "
"DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!");
}
void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override
{
auto arg_ptr = dynamic_cast<Argument*>(p_arg);
if(arg_ptr)
{
arg_ptr->UpdateKBatch(kbatch);
}
else
throw std::runtime_error("The argument pointer is not an object of "
"DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!");
}
};

View File

@@ -546,7 +546,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
bool supported = true;
for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
{
const auto& a = arg.gemm_kernel_args_[i].karg_;
const auto& a = arg.gemm_kernel_args_[i].karg_;
bool group_arg_valid = GridwiseGemm::CheckValidity(a);
if(not group_arg_valid)
{
@@ -636,16 +637,42 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{
return dynamic_cast<const Argument*>(p_arg)->gemm_kernel_args_.size() *
sizeof(GemmTransKernelArg);
auto p_arg_ = dynamic_cast<const Argument*>(p_arg);
if(p_arg_)
{
return p_arg_->gemm_kernel_args_.size() * sizeof(GemmTransKernelArg);
}
else
throw std::runtime_error(
"The argument pointer is not an object of "
"DeviceGroupedGemmMultipleDSplitKXdlCShuffle::Argument structure!");
}
size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
{
return GetWorkSpaceSize(p_arg);
}
// TODO: deperecation notice.
static void SetKBatchSize(Argument& arg, index_t kbatch) { arg.UpdateKBatch(kbatch); }
// polymorphic
void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override
{
return SetKBatchSize(*dynamic_cast<Argument*>(p_arg), kbatch);
auto p_arg_ = dynamic_cast<Argument*>(p_arg);
if(p_arg_)
{
p_arg_->UpdateKBatch(kbatch);
}
else
throw std::runtime_error(
"The argument pointer is not an object of "
"DeviceGroupedGemmMultipleDSplitKXdlCShuffle::Argument structure!");
}
void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
{
return this->SetWorkSpacePointer(p_arg, p_dev_kernel_args);
}
};

View File

@@ -5,7 +5,6 @@
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
namespace ck {

View File

@@ -95,6 +95,45 @@ void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_pv1_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_pv1_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_pv2_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Col,
@@ -189,6 +228,124 @@ void add_device_grouped_gemm_multiple_d_xdl_two_stage_bf16_bf16_bf16_mk_nk_mn_in
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_kn_mn_irregular_pv1_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
Empty_Tuple,
Row,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_kn_mn_irregular_pv1_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
Empty_Tuple,
Row,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_kn_mn_irregular_pv2_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
Empty_Tuple,
Row,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_nk_mn_irregular_pv1_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Col,
Empty_Tuple,
Row,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_nk_mn_irregular_pv1_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Col,
Empty_Tuple,
Row,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_nk_mn_irregular_pv2_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Col,
Empty_Tuple,
Row,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_km_kn_mn_irregular_pv1_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
Row,
Empty_Tuple,
Row,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_km_kn_mn_irregular_pv1_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
Row,
Empty_Tuple,
Row,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_km_kn_mn_irregular_pv2_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
Row,
Empty_Tuple,
Row,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#if defined(CK_ENABLE_BF16) && defined(CK_ENABLE_INT8)
@@ -262,7 +419,11 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances(
add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_pv1_inter_instances(
op_ptrs);
add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_pv1_instances(
op_ptrs);
add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_pv2_instances(
op_ptrs);
add_device_grouped_gemm_multiple_d_xdl_two_stage_f16_f16_f16_mk_kn_mn_instances(
op_ptrs);
@@ -334,12 +495,34 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
{
add_device_grouped_gemm_multiple_d_xdl_two_stage_bf16_bf16_bf16_mk_kn_mn_instances(
op_ptrs);
add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_kn_mn_irregular_pv1_inter_instances(
op_ptrs);
add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_kn_mn_irregular_pv1_instances(
op_ptrs);
add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_kn_mn_irregular_pv2_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_multiple_d_xdl_two_stage_bf16_bf16_bf16_mk_nk_mn_instances(
op_ptrs);
add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_nk_mn_irregular_pv1_inter_instances(
op_ptrs);
add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_nk_mn_irregular_pv1_instances(
op_ptrs);
add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_nk_mn_irregular_pv2_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_km_kn_mn_irregular_pv1_inter_instances(
op_ptrs);
add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_km_kn_mn_irregular_pv1_instances(
op_ptrs);
add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_km_kn_mn_irregular_pv2_instances(
op_ptrs);
}
}
#endif

View File

@@ -0,0 +1,138 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/utility/loop_scheduler.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Empty_Tuple = ck::Tuple<>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto PipelineV1 = ck::PipelineVersion::v1;
static constexpr auto PipelineV2 = ck::PipelineVersion::v2;
static constexpr auto DefaultScheduler = ck::LoopScheduler::Default;
static constexpr auto InterwaveScheduler = ck::LoopScheduler::Interwave;
static constexpr auto GemmMNKPadding = device::GemmSpecialization::MNKPadding;
static constexpr auto GemmDefault = device::GemmSpecialization::Default;
template <typename T,
device::GemmSpecialization GemmSpec = GemmMNKPadding,
PipelineVersion Pipeline = PipelineV1,
LoopScheduler Scheduler = DefaultScheduler,
enable_if_t<sizeof(T) == 2, bool> = false>
using device_grouped_gemm_xdl_splitk_2Bt_rrr_instances = std::tuple<
// clang-format off
//################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline | Loop |
//################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Version | Scheduler |
//################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 192, 64, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 32, 64, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 64, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Pipeline, Scheduler>
// clang-format on
>;
template <typename T,
device::GemmSpecialization GemmSpec = GemmMNKPadding,
PipelineVersion Pipeline = PipelineV1,
LoopScheduler Scheduler = DefaultScheduler,
enable_if_t<sizeof(T) == 2, bool> = false>
using device_grouped_gemm_xdl_splitk_2Bt_rcr_instances = std::tuple<
// clang-format off
//################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline | Loop |
//################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Version | Scheduler |
//################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Pipeline, Scheduler>
// clang-format on
>;
template <typename T,
device::GemmSpecialization GemmSpec = GemmMNKPadding,
PipelineVersion Pipeline = PipelineV1,
LoopScheduler Scheduler = DefaultScheduler,
enable_if_t<sizeof(T) == 2, bool> = false>
using device_grouped_gemm_xdl_splitk_2Bt_crr_instances = std::tuple<
// clang-format off
//################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline | Loop |
//################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Version | Scheduler |
//################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemmXdlSplitKCShuffle< Col, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Col, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Col, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 192, 64, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Col, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Col, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Col, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Col, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Col, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Col, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Col, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Col, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Col, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<1, 2, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Col, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 32, 64, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Col, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Col, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Col, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Col, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 64, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Col, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Col, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Pipeline, Scheduler>,
DeviceGroupedGemmXdlSplitKCShuffle< Col, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Pipeline, Scheduler>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -4,12 +4,30 @@ add_instance_library(device_grouped_gemm_instance
device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp
device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp
device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_pv1_inter.cpp
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_pv1.cpp
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_pv2.cpp
device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_km_kn_mn_irregular_pv1_inter.cpp
device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_km_kn_mn_irregular_pv1.cpp
device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_km_kn_mn_irregular_pv2.cpp
device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_kn_mn_irregular_pv1_inter.cpp
device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_kn_mn_irregular_pv1.cpp
device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_kn_mn_irregular_pv2.cpp
device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_nk_mn_irregular_pv1_inter.cpp
device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_nk_mn_irregular_pv1.cpp
device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_nk_mn_irregular_pv2.cpp
device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instance.cpp
device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instance.cpp
device_grouped_gemm_multiple_d_splitk_xdl_two_stage_f16_f16_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_multiple_d_splitk_xdl_two_stage_bf16_bf16_bf16_mk_kn_mn_instance.cpp
device_grouped_gemm_multiple_d_splitk_xdl_two_stage_bf16_bf16_bf16_mk_nk_mn_instance.cpp

View File

@@ -0,0 +1,32 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_km_kn_mn_irregular_pv1_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
Row,
Empty_Tuple,
Row,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances, device_grouped_gemm_xdl_splitk_2Bt_crr_instances<BF16, GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,36 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_km_kn_mn_irregular_pv1_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
Row,
Empty_Tuple,
Row,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_splitk_2Bt_crr_instances<BF16,
GemmMNKPadding,
PipelineV1,
InterwaveScheduler>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,33 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_km_kn_mn_irregular_pv2_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
Row,
Empty_Tuple,
Row,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_splitk_2Bt_crr_instances<BF16, GemmMNKPadding, PipelineV2>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,32 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_kn_mn_irregular_pv1_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
Empty_Tuple,
Row,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances, device_grouped_gemm_xdl_splitk_2Bt_rrr_instances<BF16, GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,36 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_kn_mn_irregular_pv1_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
Empty_Tuple,
Row,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_splitk_2Bt_rrr_instances<BF16,
GemmMNKPadding,
PipelineV1,
InterwaveScheduler>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,38 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_instance.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_kn_mn_irregular_pv2_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
Empty_Tuple,
Row,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_splitk_2Bt_rrr_instances<BF16, GemmMNKPadding, PipelineV2>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,32 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_nk_mn_irregular_pv1_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Col,
Empty_Tuple,
Row,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances, device_grouped_gemm_xdl_splitk_2Bt_rcr_instances<BF16, GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,36 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_nk_mn_irregular_pv1_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Col,
Empty_Tuple,
Row,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_splitk_2Bt_rcr_instances<BF16,
GemmMNKPadding,
PipelineV1,
InterwaveScheduler>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,33 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_nk_mn_irregular_pv2_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Col,
Empty_Tuple,
Row,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_splitk_2Bt_rcr_instances<BF16, GemmMNKPadding, PipelineV2>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,53 +1,14 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp"
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Empty_Tuple = ck::Tuple<>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// a[m, k] * b[k, n] = e[m, n]
using device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances = std::tuple<
// clang-format off
//################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
// clang-format on
>;
void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
@@ -61,8 +22,8 @@ void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances{});
add_device_operation_instances(
instances, device_grouped_gemm_xdl_splitk_2Bt_rrr_instances<F16, GemmDefault>{});
}
} // namespace instance

View File

@@ -1,123 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Empty_Tuple = ck::Tuple<>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_tile_instances = std::tuple<
// clang-format off
//################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v1>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v1>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 192, 64, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v1>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v1>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v1>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v1>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v1>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, PipelineVersion::v1>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v1>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, PipelineVersion::v1>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, PipelineVersion::v1>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v1>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 64, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v1>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v1>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, PipelineVersion::v1>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v1>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, PipelineVersion::v1>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, PipelineVersion::v1>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, PipelineVersion::v1>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, PipelineVersion::v1>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v1, LoopScheduler::Interwave>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v1, LoopScheduler::Interwave>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 192, 64, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v1, LoopScheduler::Interwave>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v1, LoopScheduler::Interwave>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v1, LoopScheduler::Interwave>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v1, LoopScheduler::Interwave>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v1, LoopScheduler::Interwave>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, PipelineVersion::v1, LoopScheduler::Interwave>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v1, LoopScheduler::Interwave>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, PipelineVersion::v1, LoopScheduler::Interwave>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, PipelineVersion::v1, LoopScheduler::Interwave>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v1, LoopScheduler::Interwave>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 64, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v1, LoopScheduler::Interwave>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v1, LoopScheduler::Interwave>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, PipelineVersion::v1, LoopScheduler::Interwave>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v1, LoopScheduler::Interwave>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, PipelineVersion::v1, LoopScheduler::Interwave>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, PipelineVersion::v1, LoopScheduler::Interwave>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, PipelineVersion::v1, LoopScheduler::Interwave>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, PipelineVersion::v1, LoopScheduler::Interwave>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v2>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v2>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 192, 64, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v2>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v2>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v2>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v2>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v2>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, PipelineVersion::v2>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v2>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, PipelineVersion::v2>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, PipelineVersion::v2>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v2>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 64, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v2>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v2>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, PipelineVersion::v2>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v2>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, PipelineVersion::v2>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, PipelineVersion::v2>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, PipelineVersion::v2>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, PipelineVersion::v2>
// clang-format on
>;
void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances, device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_tile_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,32 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_pv1_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances, device_grouped_gemm_xdl_splitk_2Bt_rrr_instances<F16, GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,36 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_pv1_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_splitk_2Bt_rrr_instances<F16,
GemmMNKPadding,
PipelineV1,
InterwaveScheduler>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,33 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_pv2_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_splitk_2Bt_rrr_instances<F16, GemmMNKPadding, PipelineV2>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,57 +1,14 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp"
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Empty_Tuple = ck::Tuple<>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// a[m, k] * b[n, k] = e[m, n]
using device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances = std::tuple<
// clang-format off
//################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>
// clang-format on
>;
void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Col,
@@ -65,8 +22,8 @@ void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances{});
add_device_operation_instances(
instances, device_grouped_gemm_xdl_splitk_2Bt_rcr_instances<F16, GemmDefault>{});
}
} // namespace instance

View File

@@ -1,63 +1,14 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp"
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Empty_Tuple = ck::Tuple<>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_tile_instances = std::tuple<
// clang-format off
//################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 192, 64, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 48, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 256, 32, 8, 8, 32, 32, 1, 4, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 64, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>
// clang-format on
>;
void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Col,
@@ -72,7 +23,7 @@ void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances
PassThrough>>>& instances)
{
add_device_operation_instances(
instances, device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_tile_instances{});
instances, device_grouped_gemm_xdl_splitk_2Bt_rcr_instances<F16, GemmMNKPadding>{});
}
} // namespace instance

View File

@@ -1,234 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using BF16 = ck::bhalf_t;
using I8 = int8_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Multiply = ck::tensor_operation::element_wise::Multiply;
using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu;
using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu;
using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd;
static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
template <typename DsLayout,
typename DsDataType,
typename CDEElementwiseOp,
GemmSpecialization GemmSpec = GemmMNKPadding>
using device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances = std::tuple<
// clang-format off
//###########################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//###########################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//###########################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S<C,D0...,D_N|
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 224, 256, 64, 8, 4, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 2, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
// clang-format on
>;
template <typename DsLayout,
typename DsDataType,
typename CDEElementwiseOp,
GemmSpecialization GemmSpec = GemmMNKPadding,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave>
using device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances =
std::tuple<
// clang-format off
//###########################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//###########################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//###########################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S<C,D0...,D_N|
// Latency friendly
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 64, 16, 16, 256, 8, 4, 16, 16, 1, 1, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 4>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 16, 32, 256, 8, 4, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
// Memory friendly
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 64, 16, 16, 256, 8, 4, 16, 16, 1, 1, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 4>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 16, 32, 256, 8, 4, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 16, 64, 128, 8, 4, 16, 16, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 32, 64, 128, 8, 4, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<8,8,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 16, 128, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 32, 128, 64, 8, 4, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<8,8,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 16, 256, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
// DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 32, 256, 64, 8, 4, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, S<8,8,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>
// clang-format on
>;
void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
Row,
ck::Tuple<Row>,
Row,
BF16,
I8,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
Multiply>>>& instances)
{
// comp
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances<ck::Tuple<Row>,
ck::Tuple<BF16>,
Multiply,
GemmDefault>{});
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances<ck::Tuple<Row>,
ck::Tuple<BF16>,
Multiply,
GemmMNKPadding>{});
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances<ck::Tuple<Row>,
ck::Tuple<BF16>,
Multiply,
GemmMNPadding>{});
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances<ck::Tuple<Row>,
ck::Tuple<BF16>,
Multiply,
GemmKPadding>{});
// mem
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<Row>,
ck::Tuple<BF16>,
Multiply,
GemmDefault,
Intrawave>{});
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<Row>,
ck::Tuple<BF16>,
Multiply,
GemmMNKPadding,
Intrawave>{});
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<Row>,
ck::Tuple<BF16>,
Multiply,
GemmMNPadding,
Intrawave>{});
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<Row>,
ck::Tuple<BF16>,
Multiply,
GemmKPadding,
Intrawave>{});
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<Row>,
ck::Tuple<BF16>,
Multiply,
GemmDefault,
Interwave>{});
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<Row>,
ck::Tuple<BF16>,
Multiply,
GemmMNKPadding,
Interwave>{});
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<Row>,
ck::Tuple<BF16>,
Multiply,
GemmMNPadding,
Interwave>{});
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<Row>,
ck::Tuple<BF16>,
Multiply,
GemmKPadding,
Interwave>{});
}
void add_device_grouped_gemm_xdl_tile_loop_multiply_bias_fastgelu_bf16_i8_bf16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
Row,
ck::Tuple<Row, Row>,
Row,
BF16,
I8,
ck::Tuple<BF16, BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyAddFastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_irregular_tile_instances<
ck::Tuple<Row, Row>,
ck::Tuple<BF16, BF16>,
MultiplyAddFastGelu>{});
}
void add_device_grouped_gemm_xdl_tile_loop_multiply_fastgelu_bf16_i8_bf16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmTileLoop<Row,
Row,
ck::Tuple<Row>,
Row,
BF16,
I8,
ck::Tuple<BF16>,
BF16,
PassThrough,
PassThrough,
MultiplyFastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_irregular_tile_instances<
ck::Tuple<Row>,
ck::Tuple<BF16>,
MultiplyFastGelu>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -17,7 +17,6 @@
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
@@ -42,11 +41,14 @@ bool profile_grouped_gemm_impl(int do_verification,
const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs,
int kbatch = 1,
int n_warmup = 1,
int n_iter = 10)
const std::vector<int>& kbatches = {},
int n_warmup = 1,
int n_iter = 10)
{
bool pass = true;
// TODO: Fixme - we do not pass compute data type here but need it
// to compute error thresholds.
using ComputeDataType = ADataType;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
@@ -75,6 +77,7 @@ bool profile_grouped_gemm_impl(int do_verification,
std::vector<Tensor<CDataType>> c_m_n_host_results;
std::vector<Tensor<CDataType>> c_m_n_device_results;
ComputeDataType max_abs_in_val = 0.f;
for(std::size_t i = 0; i < group_count; i++)
{
a_m_k.push_back(
@@ -93,17 +96,18 @@ bool profile_grouped_gemm_impl(int do_verification,
<< i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i
<< "]:" << c_m_n_device_results[i].mDesc << std::endl;
}
std::size_t num_thread = 1;
switch(init_method)
{
case 0: break;
case 1:
a_m_k[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}, num_thread);
b_k_n[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}, num_thread);
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-2.f, 2.f}(a_m_k[i]);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-2.f, 2.f}(b_k_n[i]);
max_abs_in_val = 2.f;
break;
default:
a_m_k[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}, num_thread);
b_k_n[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread);
ck::utils::FillUniformDistribution<ADataType>{-0.5f, 0.5f}(a_m_k[i]);
ck::utils::FillUniformDistribution<BDataType>{-0.5f, 0.5f}(b_k_n[i]);
max_abs_in_val = 0.5f;
}
}
@@ -164,7 +168,20 @@ bool profile_grouped_gemm_impl(int do_verification,
BElementOp,
CElementOp>;
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
// If kbatch would be bigger than 1, then we will use SplitK version.
using DeviceOpSplitK = ck::tensor_operation::device::DeviceGroupedGemmSplitK<ALayout,
BLayout,
ck::Tuple<>,
CLayout,
ADataType,
BDataType,
ck::Tuple<>,
CDataType,
AElementOp,
BElementOp,
CElementOp>;
auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
if(op_ptrs.size() <= 0)
@@ -205,7 +222,6 @@ bool profile_grouped_gemm_impl(int do_verification,
ref_invoker.Run(ref_argument);
}
}
// profile device GEMM instances
for(auto& gemm_ptr : op_ptrs)
{
@@ -221,43 +237,44 @@ bool profile_grouped_gemm_impl(int do_verification,
auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
DeviceMem gemm_desc_workspace(gemm_ptr->GetWorkSpaceSize(argument_ptr.get()));
std::size_t workspace_size = gemm_ptr->GetWorkSpaceSize(argument_ptr.get());
std::size_t kargs_size = gemm_ptr->GetDeviceKernelArgSize(argument_ptr.get());
gemm_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer());
std::string gemm_name = gemm_ptr->GetTypeString();
DeviceMem gemm_workspace, gemm_kargs;
using DeviceOpSplitK = ck::tensor_operation::device::DeviceGroupedGemmSplitK<ALayout,
BLayout,
ck::Tuple<>,
CLayout,
ADataType,
BDataType,
ck::Tuple<>,
CDataType,
AElementOp,
BElementOp,
CElementOp>;
// skip non-splitk grouped_gemm
if(dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get()) == nullptr)
// The following is necessary since TwoStage kernel is using additional memory both
// for Workspace and kernel arguments.
if(kargs_size > 0)
{
continue;
gemm_kargs.Realloc(kargs_size);
gemm_ptr->SetDeviceKernelArgs(argument_ptr.get(), gemm_kargs.GetDeviceBuffer());
}
if(workspace_size > 0 && workspace_size != kargs_size)
{
gemm_workspace.Realloc(workspace_size);
gemm_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_workspace.GetDeviceBuffer());
}
std::string gemm_name = gemm_ptr->GetTypeString();
std::vector<int> kbatch_list = {1, 2, 4, 8, 12, 16, 20, 24, 32, 48, 64};
if(kbatch > 0)
// If the user will provide not empty kbatches list, then we test predefined set of kbatch
// values.
if(!kbatches.empty())
{
kbatch_list = {kbatch};
kbatch_list = kbatches;
}
for(std::size_t j = 0; j < kbatch_list.size(); j++)
{
auto kbatch_curr = kbatch_list[j];
dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get())
->SetKBatchSize(argument_ptr.get(), kbatch_curr);
if(kbatch_curr > 1 && dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get()) != nullptr)
{
dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get())
->SetKBatchSize(argument_ptr.get(), kbatch_curr);
}
if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
{
@@ -272,23 +289,18 @@ bool profile_grouped_gemm_impl(int do_verification,
bool instance_pass = true;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data());
auto atol = ck::utils::get_absolute_threshold<ComputeDataType, CDataType>(
max_abs_in_val, gemm_descs[i].K_);
auto rtol = ck::utils::get_relative_threshold<ComputeDataType, CDataType>(
gemm_descs[i].K_);
if(std::is_same_v<CDataType, ck::half_t> && kbatch_curr > 1)
{
instance_pass =
instance_pass && ck::utils::check_err(c_m_n_device_results[i],
c_m_n_host_results[i],
"Error: Incorrect results!",
0.06);
}
else
{
instance_pass =
instance_pass && ck::utils::check_err(c_m_n_device_results[i],
c_m_n_host_results[i]);
}
instance_pass =
instance_pass && ck::utils::check_err(c_m_n_device_results[i],
c_m_n_host_results[i],
"Error: Incorrect results!",
rtol,
atol);
if(do_log)
{
@@ -311,11 +323,12 @@ bool profile_grouped_gemm_impl(int do_verification,
pass = pass && instance_pass;
}
float ave_time = invoker_ptr->Run(
argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter});
if(time_kernel)
{
float ave_time =
invoker_ptr->Run(argument_ptr.get(),
StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter});
std::size_t flop = 0, num_btype = 0;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{

View File

@@ -143,8 +143,7 @@ bool profile_grouped_gemm_multiply_tile_loop_impl(int do_verification,
p_ds.reserve(group_count);
p_e.reserve(group_count);
using KernelArguments =
ck::tensor_operation::device::GroupedGemmTileLoopKernelArguments<NumDTensor>;
using KernelArguments = ck::tensor_operation::device::GroupedGemmKernelArgument<NumDTensor>;
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
std::vector<KernelArguments> gemm_kargs;

View File

@@ -127,7 +127,7 @@ bool profile_grouped_gemm_tile_loop_impl(int do_verification,
p_b.reserve(group_count);
p_c.reserve(group_count);
using KernelArguments = ck::tensor_operation::device::GroupedGemmTileLoopKernelArguments<>;
using KernelArguments = ck::tensor_operation::device::GroupedGemmKernelArgument<>;
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
std::vector<KernelArguments> gemm_kargs;

View File

@@ -1,367 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iomanip>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
namespace ck {
namespace profiler {
template <typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename ALayout,
typename BLayout,
typename CLayout>
bool profile_grouped_gemm_two_stage_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs,
int kbatch = 1,
int n_warmup = 1,
int n_iter = 10)
{
bool pass = true;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
std::size_t group_count = Ms.size();
if(!(group_count == Ns.size() && group_count == Ks.size() && group_count == StrideAs.size() &&
group_count == StrideBs.size() && group_count == StrideCs.size()))
{
throw std::runtime_error("wrong! inconsistent M/N/Ks, StrideA/B/Cs size\n");
}
std::vector<Tensor<ADataType>> a_m_k;
std::vector<Tensor<BDataType>> b_k_n;
std::vector<Tensor<CDataType>> c_m_n_host_results;
std::vector<Tensor<CDataType>> c_m_n_device_results;
for(std::size_t i = 0; i < group_count; i++)
{
a_m_k.push_back(
Tensor<ADataType>(f_host_tensor_descriptor(Ms[i], Ks[i], StrideAs[i], ALayout{})));
b_k_n.push_back(
Tensor<BDataType>(f_host_tensor_descriptor(Ks[i], Ns[i], StrideBs[i], BLayout{})));
c_m_n_device_results.push_back(
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
c_m_n_host_results.push_back(
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n["
<< i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i
<< "]:" << c_m_n_device_results[i].mDesc << std::endl;
}
std::size_t num_thread = 1;
switch(init_method)
{
case 0: break;
case 1:
a_m_k[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}, num_thread);
b_k_n[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}, num_thread);
break;
default:
a_m_k[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}, num_thread);
b_k_n[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread);
}
}
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
const auto a_element_op = AElementOp{};
const auto b_element_op = BElementOp{};
const auto c_element_op = CElementOp{};
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
std::vector<DeviceMemPtr> a_device_buf, b_device_buf, c_device_buf;
a_device_buf.reserve(group_count);
b_device_buf.reserve(group_count);
c_device_buf.reserve(group_count);
std::vector<const void*> p_a, p_b;
std::vector<void*> p_c;
p_a.reserve(group_count);
p_b.reserve(group_count);
p_c.reserve(group_count);
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
gemm_descs.reserve(group_count);
for(std::size_t i = 0; i < group_count; i++)
{
a_device_buf.emplace_back(
std::make_unique<DeviceMem>(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpaceSize()));
b_device_buf.emplace_back(
std::make_unique<DeviceMem>(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpaceSize()));
c_device_buf.emplace_back(std::make_unique<DeviceMem>(
sizeof(CDataType) * c_m_n_device_results[i].mDesc.GetElementSpaceSize()));
a_device_buf[i]->ToDevice(a_m_k[i].mData.data());
b_device_buf[i]->ToDevice(b_k_n[i].mData.data());
gemm_descs.push_back({Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}});
p_a.push_back(a_device_buf[i]->GetDeviceBuffer());
p_b.push_back(b_device_buf[i]->GetDeviceBuffer());
p_c.push_back(c_device_buf[i]->GetDeviceBuffer());
}
using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemm<ALayout,
BLayout,
ck::Tuple<>,
CLayout,
ADataType,
BDataType,
ck::Tuple<>,
CDataType,
AElementOp,
BElementOp,
CElementOp>;
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
if(op_ptrs.size() <= 0)
{
throw std::runtime_error("wrong! no device GEMM instance found");
}
std::string best_gemm_name;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
float best_kbatch = 0;
auto p_ds = std::vector<std::array<const void*, 0>>{};
if(do_verification)
{
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
CElementOp>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_m_k[i],
b_k_n[i],
c_m_n_host_results[i],
a_element_op,
b_element_op,
c_element_op);
ref_invoker.Run(ref_argument);
}
}
// profile device GEMM instances
for(auto& gemm_ptr : op_ptrs)
{
auto argument_ptr =
gemm_ptr->MakeArgumentPointer(p_a,
p_b,
p_ds,
p_c,
gemm_descs,
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{});
auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
DeviceMem gemm_desc_workspace(gemm_ptr->GetWorkSpaceSize(argument_ptr.get()));
gemm_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer());
std::string gemm_name = gemm_ptr->GetTypeString();
using DeviceOpSplitK =
ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitK<ALayout,
BLayout,
ck::Tuple<>,
CLayout,
ADataType,
BDataType,
ck::Tuple<>,
CDataType,
AElementOp,
BElementOp,
CElementOp>;
// skip non-splitk grouped_gemm
if(dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get()) == nullptr)
{
continue;
}
std::vector<int> kbatch_list = {1, 2, 4, 8, 12, 16, 20, 24, 32, 48, 64};
if(kbatch > 0)
{
kbatch_list = {kbatch};
}
for(std::size_t j = 0; j < kbatch_list.size(); j++)
{
auto kbatch_curr = kbatch_list[j];
dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get())
->SetKBatchSize(argument_ptr.get(), kbatch_curr);
DeviceMem gemm_arg_dev_mem(dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get())
->GetDeviceKernelArgSize(argument_ptr.get()));
dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get())
->SetDeviceKernelArgs(argument_ptr.get(), gemm_arg_dev_mem.GetDeviceBuffer());
if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
{
gemm_desc_workspace.SetZero();
for(std::size_t i = 0; i < gemm_descs.size(); i++)
c_device_buf[i]->SetZero();
invoker_ptr->Run(argument_ptr.get(),
StreamConfig{nullptr, false, 0, n_warmup, n_iter});
if(do_verification)
{
bool instance_pass = true;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data());
if(std::is_same_v<CDataType, ck::half_t> && kbatch_curr > 1)
{
instance_pass =
instance_pass && ck::utils::check_err(c_m_n_device_results[i],
c_m_n_host_results[i],
"Error: Incorrect results!",
0.06);
}
else
{
instance_pass =
instance_pass && ck::utils::check_err(c_m_n_device_results[i],
c_m_n_host_results[i]);
}
if(do_log)
{
LogRangeAsType<float>(std::cout << "a : ", a_m_k[i].mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "b: ", b_k_n[i].mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "c_device: ", c_m_n_device_results[i].mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "c_host : ", c_m_n_host_results[i].mData, ",")
<< std::endl;
}
}
std::cout << "Instance: " << gemm_name << " verification "
<< (instance_pass ? "SUCCEED" : "FAILED") << std::endl;
pass = pass && instance_pass;
}
float ave_time = invoker_ptr->Run(
argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter});
if(time_kernel)
{
std::size_t flop = 0, num_btype = 0;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i];
num_btype += sizeof(ADataType) * Ms[i] * Ks[i] +
sizeof(BDataType) * Ks[i] * Ns[i] +
sizeof(CDataType) * Ms[i] * Ns[i];
}
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops
<< " TFlops, " << gb_per_sec << " GB/s, " << gemm_name << ", KBatch "
<< kbatch_curr << std::endl;
if(tflops > best_tflops)
{
best_gemm_name = gemm_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
best_kbatch = kbatch_curr;
}
}
}
else
{
std::cout << "Instance: " << gemm_name << ", does not support this GEMM problem"
<< std::endl;
}
}
}
if(time_kernel)
{
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_gemm_name << ", KBatch = " << best_kbatch
<< std::endl;
}
return pass;
}
} // namespace profiler
} // namespace ck

View File

@@ -43,7 +43,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
list(APPEND PROFILER_SOURCES profile_gemm_add_silu.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_gemm_fixed_nk.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_gemm_two_stage.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_gemm_tile_loop.cpp)
list(APPEND PROFILER_SOURCES profile_grouped_gemm_multiply_tile_loop.cpp)

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
@@ -39,16 +39,13 @@ namespace {
std::vector<int> argToIntArray(char* input)
{
std::vector<int> out;
std::istringstream in(input);
std::string item;
while(std::getline(in, item, ','))
{
out.push_back(std::stoi(item));
}
return out;
}
@@ -69,7 +66,7 @@ int profile_grouped_gemm(int argc, char* argv[])
<< "arg7: time kernel (0=n0, 1=yes)\n"
<< "arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
"64,64 64,64 128,128)\n"
<< "arg15: kbatch value (default 1)\n"
<< "arg15: kbatch values (default 1)\n"
<< "optional:\n"
<< "arg16: number of warm-up cycles (default 1)\n"
<< "arg17: number of iterations (default 10)\n"
@@ -92,7 +89,7 @@ int profile_grouped_gemm(int argc, char* argv[])
const auto StrideAs = argToIntArray(argv[11]);
const auto StrideBs = argToIntArray(argv[12]);
const auto StrideCs = argToIntArray(argv[13]);
const int kbatch = argc == 15 ? std::stoi(argv[14]) : 1;
const auto kbatches = argc >= 15 ? argToIntArray(argv[14]) : std::vector<int>{};
int n_warmup = 1;
int n_iter = 10;
@@ -102,7 +99,6 @@ int profile_grouped_gemm(int argc, char* argv[])
n_iter = std::stoi(argv[16]);
}
#ifdef CK_ENABLE_FP16
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_grouped_gemm_impl<ck::half_t,
@@ -121,7 +117,7 @@ int profile_grouped_gemm(int argc, char* argv[])
StrideAs,
StrideBs,
StrideCs,
kbatch,
kbatches,
n_warmup,
n_iter);
}
@@ -143,7 +139,7 @@ int profile_grouped_gemm(int argc, char* argv[])
StrideAs,
StrideBs,
StrideCs,
kbatch,
kbatches,
n_warmup,
n_iter);
}
@@ -165,7 +161,7 @@ int profile_grouped_gemm(int argc, char* argv[])
StrideAs,
StrideBs,
StrideCs,
kbatch,
kbatches,
n_warmup,
n_iter);
}
@@ -187,7 +183,7 @@ int profile_grouped_gemm(int argc, char* argv[])
StrideAs,
StrideBs,
StrideCs,
kbatch,
kbatches,
n_warmup,
n_iter);
}
@@ -209,7 +205,7 @@ int profile_grouped_gemm(int argc, char* argv[])
StrideAs,
StrideBs,
StrideCs,
kbatch,
kbatches,
n_warmup,
n_iter);
}
@@ -231,7 +227,73 @@ int profile_grouped_gemm(int argc, char* argv[])
StrideAs,
StrideBs,
StrideCs,
kbatch,
kbatches,
n_warmup,
n_iter);
}
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_grouped_gemm_impl<ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatches,
n_warmup,
n_iter);
}
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_grouped_gemm_impl<ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatches,
n_warmup,
n_iter);
}
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_KN_MN)
{
ck::profiler::profile_grouped_gemm_impl<ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t,
float,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatches,
n_warmup,
n_iter);
}
@@ -239,7 +301,6 @@ int profile_grouped_gemm(int argc, char* argv[])
{
throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented");
}
#endif
return 0;
}

View File

@@ -32,9 +32,7 @@ namespace {
std::vector<int> argToIntArray(char* input)
{
std::vector<int> out;
std::istringstream in(input);
std::string item;
while(std::getline(in, item, ','))
@@ -83,7 +81,7 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
const auto StrideAs = argToIntArray(argv[11]);
const auto StrideBs = argToIntArray(argv[12]);
const auto StrideCs = argToIntArray(argv[13]);
const int kbatch = argc == 15 ? std::stoi(argv[14]) : 1;
const int kbatch = argc >= 15 ? std::stoi(argv[14]) : 1;
using F32 = float;
using F16 = ck::half_t;
@@ -97,8 +95,8 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
int n_iter = 10;
if(argc == 17)
{
n_warmup = std::stoi(argv[16]);
n_iter = std::stoi(argv[17]);
n_warmup = std::stoi(argv[15]);
n_iter = std::stoi(argv[16]);
}
#if defined(CK_ENABLE_BF16) && defined(CK_ENABLE_INT8)

View File

@@ -1,228 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "profiler/profile_grouped_gemm_two_stage_impl.hpp"
#include "profiler_operation_registry.hpp"
enum struct GemmMatrixLayout
{
MK_KN_MN, // 0
MK_NK_MN, // 1
};
enum struct GemmDataType
{
F16_F16_F16, // 0
BF16_INT8_BF16, // 1
BF16_BF16_BF16 // 2
};
#define OP_NAME "grouped_gemm_two_stage"
#define OP_DESC "Grouped GEMM TwoStage"
namespace {
std::vector<int> argToIntArray(char* input)
{
std::vector<int> out;
std::istringstream in(input);
std::string item;
while(std::getline(in, item, ','))
{
out.push_back(std::stoi(item));
}
return out;
}
int profile_grouped_gemm_two_stage(int argc, char* argv[])
{
if(argc < 14)
{
std::cout
<< "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
<< "arg2: data type (0: fp16; 1: bf16@int8; 2: bf16)\n"
<< "arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n]);\n"
<< "arg4: verification (0: no; 1: yes)\n"
<< "arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"
<< "arg6: print tensor value (0: no; 1: yes)\n"
<< "arg7: time kernel (0=n0, 1=yes)\n"
<< "arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
"64,64 64,64 128,128)\n"
<< "arg15: kbatch value (default 1)\n"
<< "optional:\n"
<< "arg16: number of warm-up cycles (default 1)\n"
<< "arg17: number of iterations (default 10)\n"
<< std::endl;
exit(1);
}
const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
const bool do_verification = std::stoi(argv[4]);
const int init_method = std::stoi(argv[5]);
const bool do_log = std::stoi(argv[6]);
const bool time_kernel = std::stoi(argv[7]);
const auto Ms = argToIntArray(argv[8]);
const auto Ns = argToIntArray(argv[9]);
const auto Ks = argToIntArray(argv[10]);
auto StrideAs = argToIntArray(argv[11]);
auto StrideBs = argToIntArray(argv[12]);
auto StrideCs = argToIntArray(argv[13]);
const int kbatch = argc == 15 ? std::stoi(argv[14]) : 1;
const int DefaultStrideA = Ks[0];
const int DefaultStrideB = Ns[0];
const int DefaultStrideC = Ns[0];
for(size_t i = 0; i < Ms.size(); ++i)
{
StrideAs[i] = StrideAs[i] == -1 ? DefaultStrideA : StrideAs[i];
StrideBs[i] = StrideBs[i] == -1 ? DefaultStrideB : StrideBs[i];
StrideCs[i] = StrideCs[i] == -1 ? DefaultStrideC : StrideCs[i];
}
int n_warmup = 1;
int n_iter = 10;
if(argc == 17)
{
n_warmup = std::stoi(argv[16]);
n_iter = std::stoi(argv[17]);
}
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_grouped_gemm_two_stage_impl<ck::half_t,
ck::half_t,
ck::half_t,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatch,
n_warmup,
n_iter);
}
else if(data_type == GemmDataType::BF16_INT8_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_grouped_gemm_two_stage_impl<ck::bhalf_t,
int8_t,
ck::bhalf_t,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatch,
n_warmup,
n_iter);
}
else if(data_type == GemmDataType::BF16_INT8_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_grouped_gemm_two_stage_impl<ck::bhalf_t,
int8_t,
ck::bhalf_t,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatch,
n_warmup,
n_iter);
}
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_grouped_gemm_two_stage_impl<ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatch,
n_warmup,
n_iter);
}
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_grouped_gemm_two_stage_impl<ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
time_kernel,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatch,
n_warmup,
n_iter);
}
else
{
throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented");
}
return 0;
}
} // anonymous namespace
REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_grouped_gemm_two_stage);

View File

@@ -6,12 +6,6 @@ if(result EQUAL 0)
add_dependencies(test_grouped_gemm test_grouped_gemm_splitk)
endif()
add_gtest_executable(test_grouped_gemm_two_stage_splitk test_grouped_gemm_two_stage_multiple_d_splitk_xdl.cpp)
if(result EQUAL 0)
target_link_libraries(test_grouped_gemm_two_stage_splitk PRIVATE utility device_grouped_gemm_instance)
add_dependencies(test_grouped_gemm test_grouped_gemm_two_stage_splitk)
endif()
add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface_xdl.cpp)
if(result EQUAL 0)
target_link_libraries(test_grouped_gemm_interface PRIVATE utility device_grouped_gemm_instance)

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include <vector>
@@ -10,25 +10,35 @@
#include "gtest/gtest.h"
#include "test_grouped_gemm_util.hpp"
using F16 = ck::half_t;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F8 = ck::f8_t;
using I8 = int8_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using RRR_F16_F16_F16 = ck::test::TestGroupedGemm<std::tuple<Row, Row, Row, F16, F16, F16>>;
using RCR_F16_F16_F16 = ck::test::TestGroupedGemm<std::tuple<Row, Col, Row, F16, F16, F16>>;
template <typename Tuple>
class TestGroupedGemm : public ck::test::TestGroupedGemm<Tuple>
{
};
using RRR_F16_F16_F16_LargeK = ck::test::TestGroupedGemm<std::tuple<Row, Row, Row, F16, F16, F16>>;
using RCR_F16_F16_F16_LargeK = ck::test::TestGroupedGemm<std::tuple<Row, Col, Row, F16, F16, F16>>;
// clang-format off
using KernelTypes = ::testing::Types<
std::tuple< Row, Row, Row, F16, F16, F16>,
std::tuple< Row, Col, Row, F16, F16, F16>,
std::tuple< Col, Row, Row, F16, F16, F16>,
std::tuple< Col, Col, Row, F16, F16, F16>,
std::tuple< Row, Row, Row, BF16, BF16, BF16>,
std::tuple< Row, Col, Row, BF16, BF16, BF16>,
std::tuple< Col, Row, Row, BF16, BF16, BF16>,
std::tuple< Row, Row, Row, BF16, I8, BF16>,
std::tuple< Row, Col, Row, BF16, I8, BF16>,
std::tuple< Row, Row, Row, F16, F8, F16>,
std::tuple< Row, Row, Row, F8, F16, F16>
>;
// clang-format on
const std::vector<int> KBATCH{1, 2, 3, 5, 8};
INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_MK_KN, RRR_F16_F16_F16, testing::ValuesIn(KBATCH));
INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_MK_NK, RCR_F16_F16_F16, testing::ValuesIn(KBATCH));
INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_LargeK_MK_KN,
RRR_F16_F16_F16_LargeK,
testing::Values(32, 64));
INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_LargeK_MK_NK,
RCR_F16_F16_F16_LargeK,
testing::Values(32, 64));
TYPED_TEST_SUITE(TestGroupedGemm, KernelTypes);
#include "test_grouped_gemm_ut_cases.inc"

View File

@@ -1,6 +1,6 @@
#pragma once
TEST_P(RRR_F16_F16_F16, TinyCases)
TYPED_TEST(TestGroupedGemm, TinyCases)
{
const std::vector<int> Ms{0, 1};
constexpr int N = 768;
@@ -8,14 +8,11 @@ TEST_P(RRR_F16_F16_F16, TinyCases)
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), N);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
this->Run(Ms, Ns, Ks);
}
TEST_P(RRR_F16_F16_F16, SmallCases)
TYPED_TEST(TestGroupedGemm, SmallCases)
{
const std::vector<int> Ms{2, 1, 3, 4, 5, 0};
constexpr int N = 768;
@@ -23,14 +20,11 @@ TEST_P(RRR_F16_F16_F16, SmallCases)
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), N);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
this->Run(Ms, Ns, Ks);
}
TEST_P(RRR_F16_F16_F16, MidCases)
TYPED_TEST(TestGroupedGemm, MidCases)
{
const std::vector<int> Ms{167, 183, 177, 153, 139, 204};
constexpr int N = 768;
@@ -38,14 +32,11 @@ TEST_P(RRR_F16_F16_F16, MidCases)
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), N);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
this->Run(Ms, Ns, Ks);
}
TEST_P(RRR_F16_F16_F16, Regular)
TYPED_TEST(TestGroupedGemm, Regular)
{
const std::vector<int> Ms{64, 128, 256};
constexpr int N = 768;
@@ -53,14 +44,11 @@ TEST_P(RRR_F16_F16_F16, Regular)
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), N);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
this->Run(Ms, Ns, Ks);
}
TEST_P(RRR_F16_F16_F16, MNKPadded)
TYPED_TEST(TestGroupedGemm, MNKPadded)
{
const std::vector<int> Ms{127, 150, 188, 210};
constexpr int N = 136;
@@ -68,88 +56,11 @@ TEST_P(RRR_F16_F16_F16, MNKPadded)
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), N);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
this->Run(Ms, Ns, Ks);
}
TEST_P(RCR_F16_F16_F16, TinyCases)
{
const std::vector<int> Ms{0, 1};
constexpr int N = 768;
constexpr int K = 544;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), K);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
TEST_P(RCR_F16_F16_F16, SmallCases)
{
const std::vector<int> Ms{2, 1, 3, 4, 5, 0};
constexpr int N = 768;
constexpr int K = 544;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), K);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
TEST_P(RCR_F16_F16_F16, MidCases)
{
const std::vector<int> Ms{167, 183, 177, 153, 139, 204};
constexpr int N = 768;
constexpr int K = 544;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), K);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
TEST_P(RCR_F16_F16_F16, Regular)
{
const std::vector<int> Ms{32, 64, 128, 256};
constexpr int N = 768;
constexpr int K = 320;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), K);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
TEST_P(RCR_F16_F16_F16, MNKPadded)
{
const std::vector<int> Ms{127, 150, 188, 210};
constexpr int N = 136;
constexpr int K = 280;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), K);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
TEST_P(RRR_F16_F16_F16_LargeK, TestLargeKBatch)
TYPED_TEST(TestGroupedGemm, TestLargeKBatch)
{
const std::vector<int> Ms{188, 210};
constexpr int N = 768;
@@ -157,24 +68,8 @@ TEST_P(RRR_F16_F16_F16_LargeK, TestLargeKBatch)
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), N);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
TEST_P(RCR_F16_F16_F16_LargeK, TestLargeKBatch)
{
const std::vector<int> Ms{188, 210};
constexpr int N = 768;
constexpr int K = 4096;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
const std::vector<int> StrideAs(Ms.size(), K);
const std::vector<int> StrideBs(Ms.size(), K);
const std::vector<int> StrideCs(Ms.size(), N);
this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
this->k_batches_ = {32, 64};
this->Run(Ms, Ns, Ks);
}

View File

@@ -22,7 +22,6 @@
#include "ck/utility/tuple.hpp"
#include "ck/utility/number.hpp"
#include "profiler/profile_grouped_gemm_impl.hpp"
#include "profiler/profile_grouped_gemm_two_stage_impl.hpp"
namespace ck {
namespace test {
@@ -40,7 +39,7 @@ std::string serialize_range(const Range& range)
}
template <typename Tuple>
class TestGroupedGemm : public testing::TestWithParam<int>
class TestGroupedGemm : public testing::Test
{
protected:
using ALayout = std::tuple_element_t<0, Tuple>;
@@ -50,23 +49,77 @@ class TestGroupedGemm : public testing::TestWithParam<int>
using BDataType = std::tuple_element_t<4, Tuple>;
using EDataType = std::tuple_element_t<5, Tuple>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
public:
static constexpr bool verify_ = true;
static constexpr int init_method_ = 1; // decimal value initialization
static constexpr int init_method_ = 1; // integer value initialization
static constexpr bool log_ = false;
static constexpr bool bench_ = false; // measure kernel performance
static constexpr int n_warmup_ = 0;
static constexpr int n_iter_ = 1;
std::vector<int> k_batches_;
void SetUp() override {}
void SetUp() override { k_batches_ = {1, 2, 3, 5, 8}; }
private:
template <typename Layout>
void SetStrides(std::vector<int>& strides,
const std::vector<int>& rows,
const std::vector<int>& cols) const
{
if(std::is_same_v<Layout, Row>)
{
for(const auto c : cols)
{
strides.emplace_back(c);
}
}
else if(std::is_same_v<Layout, Col>)
{
for(const auto r : rows)
{
strides.emplace_back(r);
}
}
}
public:
void Run(const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs,
int kbatch = 1,
int n_warmup = 1,
int n_iter = 10)
const std::vector<int>& StrideAs = {},
const std::vector<int>& StrideBs = {},
const std::vector<int>& StrideCs = {})
{
std::vector<int> stride_as = StrideAs;
std::vector<int> stride_bs = StrideBs;
std::vector<int> stride_cs = StrideCs;
if(stride_as.empty())
{
SetStrides<ALayout>(stride_as, Ms, Ks);
}
if(stride_bs.empty())
{
SetStrides<BLayout>(stride_bs, Ks, Ns);
}
if(stride_cs.empty())
{
SetStrides<ELayout>(stride_cs, Ms, Ns);
}
RunSingle(Ms, Ns, Ks, stride_as, stride_bs, stride_cs, k_batches_);
}
void RunSingle(const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs,
const std::vector<int>& kbatches)
{
bool pass = ck::profiler::profile_grouped_gemm_impl<ADataType,
BDataType,
@@ -84,61 +137,9 @@ class TestGroupedGemm : public testing::TestWithParam<int>
StrideAs,
StrideBs,
StrideCs,
kbatch,
n_warmup,
n_iter);
EXPECT_TRUE(pass);
}
};
template <typename Tuple>
class TestGroupedGemmTwoStage : public testing::TestWithParam<int>
{
protected:
using ALayout = std::tuple_element_t<0, Tuple>;
using BLayout = std::tuple_element_t<1, Tuple>;
using ELayout = std::tuple_element_t<2, Tuple>;
using ADataType = std::tuple_element_t<3, Tuple>;
using BDataType = std::tuple_element_t<4, Tuple>;
using EDataType = std::tuple_element_t<5, Tuple>;
public:
static constexpr bool verify_ = true;
static constexpr int init_method_ = 1; // decimal value initialization
static constexpr bool log_ = false;
static constexpr bool bench_ = false; // measure kernel performance
void SetUp() override {}
void Run(const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs,
int kbatch = 1,
int n_warmup = 1,
int n_iter = 10)
{
bool pass = ck::profiler::profile_grouped_gemm_two_stage_impl<ADataType,
BDataType,
EDataType,
float,
ALayout,
BLayout,
ELayout>(verify_,
init_method_,
log_,
bench_,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatch,
n_warmup,
n_iter);
kbatches,
n_warmup_,
n_iter_);
EXPECT_TRUE(pass);
}
};
@@ -263,7 +264,7 @@ struct DeviceGroupedGemmSplitkInstanceWrapper
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{});
if(kbatch > 1)
{
ggemm_instance.SetKBatchSize(argument, kbatch);
ggemm_instance.SetKBatchSize(&argument, kbatch);
}
return ggemm_instance.IsSupportedArgument(argument);
@@ -300,13 +301,13 @@ struct DeviceGroupedGemmSplitkInstanceWrapper
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{});
if(kbatch > 1)
{
ggemm_instance.SetKBatchSize(argument, kbatch);
ggemm_instance.SetKBatchSize(&argument, kbatch);
}
EXPECT_TRUE(ggemm_instance.IsSupportedArgument(argument));
auto invoker = ggemm_instance.MakeInvoker();
DeviceMem gemm_desc_workspace(ggemm_instance.GetWorkSpaceSize(&argument));
ggemm_instance.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer());
DeviceMem dev_gemm_kargs(ggemm_instance.GetDeviceKernelArgSize(&argument));
ggemm_instance.SetDeviceKernelArgs(&argument, dev_gemm_kargs.GetDeviceBuffer());
return invoker.Run(argument, StreamConfig{nullptr, false});
}
};