mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
Extend Grouped GEMM with MultiD (Single & Double Shared Memory) feature to use persistent kernel option (#2933)
* feat(grouped_gemm_multi_d): add new example that integrates grouped_gemm and multi_d_gemm feature
* refactor: grouped_gemm_multi_d relies on grouped_gemm_kernel
* tests(grouped_gemm): grouped_gemm test suite passes with minor adjustments
* fix: segfault fix by passing correct parameters for d tensors
* style: clang format
* WIP: host code for grouped_gemm_multi_d persistent kernel compiles but segfaults
* feat(grouped_gemm_multi_d): add functionality to run persistent kernel
* feat(grouped_gemm_multi_d): add new example that integrates grouped_gemm and multi_d_gemm feature
* refactor: grouped_gemm_multi_d relies on grouped_gemm_kernel
* tests(grouped_gemm): grouped_gemm test suite passes with minor adjustments
* fix: segfault fix by passing correct parameters for d tensors
* style: clang format
* fix: incorrect validation method and Dtensor layout in test suite
* docs: improved README text based on review comments
* fix: parameterize NumDTensor in GroupedGemmHostArgs and remove lint
[ROCm/composable_kernel commit: bebf0e9d15]
This commit is contained in:
@@ -10,16 +10,15 @@ The grouped GEMM examples include two advanced optimization features:
|
||||
Weight preshuffle is an optimization technique that reorganizes the B matrix (weights) in memory to improve data access patterns and reduce memory bandwidth requirements. This is particularly beneficial for inference workloads where the same weights are reused across multiple batches.
|
||||
|
||||
- **Implementation**: Available in `grouped_gemm_preshuffle.cpp`
|
||||
- **Configuration**: Uses `GemmConfigPreshuffleDecode` template configuration
|
||||
- **Configuration**: Uses `GemmConfigPreshuffleDecode` and `GemmConfigPreshufflePrefill` template configurations
|
||||
- **Constraints**: Currently supports only A(Row major) + B(Column major) → C(Row major) layouts
|
||||
- **Benefits**: Improved memory efficiency and reduced data movement
|
||||
|
||||
|
||||
#### Persistence Mode
|
||||
Persistence mode is a GPU optimization where thread blocks remain active on the compute units to process multiple work items sequentially, reducing kernel launch overhead and improving occupancy.
|
||||
|
||||
- **Template Parameter**: Controlled by the `Persistent` boolean template parameter in `invoke_gemm`
|
||||
- **Usage**: `invoke_gemm<ALayout, BLayout, CLayout, true>` enables persistence
|
||||
- **Benefits**: Reduced kernel launch overhead, better resource utilization for small matrix sizes
|
||||
|
||||
#### Multi-D Operations
|
||||
Multi-D operations extend the standard GEMM operation by supporting additional element-wise operations on the result tensor. This feature is particularly useful for workloads that require post-processing of the GEMM output.
|
||||
@@ -31,7 +30,8 @@ Multi-D operations extend the standard GEMM operation by supporting additional e
|
||||
- **Benefits**: Enables complex operations like scaling, activation functions, or other element-wise transformations in a single kernel call
|
||||
- **Build Target**: `make tile_example_grouped_gemm_multi_d -j`
|
||||
|
||||
Both features can be combined with different data types (fp16, fp8) and layout configurations to optimize performance for specific workloads.
|
||||
Multi-D operations support both persistence and non-persistence modes.
|
||||
Weight preshuffle is supported only in non-persistence mode.
|
||||
|
||||
## Build
|
||||
```
|
||||
@@ -48,7 +48,7 @@ make tile_example_grouped_gemm_multi_d -j
|
||||
# The quant grouped gemm fp8 example
|
||||
make tile_example_quant_grouped_gemm -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_grouped_gemm`, `build/bin/tile_example_grouped_gemm_preshuffle`, `build/bin/tile_example_grouped_gemm_multi_d`, and `build/bin/tile_example_quant_grouped_gemm`.
|
||||
Each example will result in a corresponding executable `build/bin/tile_example_grouped_gemm`, `build/bin/tile_example_grouped_gemm_preshuffle`, `build/bin/tile_example_grouped_gemm_multi_d`, and `build/bin/tile_example_quant_grouped_gemm`.
|
||||
|
||||
|
||||
## example
|
||||
|
||||
@@ -166,6 +166,112 @@ float grouped_gemm_multi_d(const std::vector<grouped_gemm_multi_d_kargs>& gemm_d
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename EDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename CDEElementWise>
|
||||
float grouped_gemm_multi_d_tileloop(const ck_tile::stream_config& s,
|
||||
const ck_tile::index_t num_groups,
|
||||
void* kargs_ptr,
|
||||
bool splitk)
|
||||
{
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
GemmConfig::TileParitionerGroupNum,
|
||||
GemmConfig::TileParitionerM01>;
|
||||
using GemmUniversalTraits =
|
||||
ck_tile::PersistentTileGemmUniversalTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout>;
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
// We create the GEMM pipeline without specifying hotloop or tailnumber.
|
||||
// These are automatically run inside the kernel based on the given input data.
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
num_groups));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
if(!splitk)
|
||||
{
|
||||
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
#include "run_grouped_gemm_multi_d_example.inc"
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
|
||||
@@ -95,6 +95,7 @@ struct GemmConfigV3 : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool Persistent = true;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
@@ -170,7 +171,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<PipelineProblem>;
|
||||
};
|
||||
|
||||
using grouped_gemm_multi_d_kargs = ck_tile::GroupedGemmHostArgs<2>;
|
||||
using grouped_gemm_multi_d_kargs = ck_tile::GroupedGemmHostArgs<DsDataType::size()>;
|
||||
|
||||
std::pair<bool, ck_tile::ArgParser> create_args(int argc, char* argv[])
|
||||
{
|
||||
@@ -201,7 +202,7 @@ std::pair<bool, ck_tile::ArgParser> create_args(int argc, char* argv[])
|
||||
|
||||
inline std::size_t get_workspace_size(const std::vector<grouped_gemm_multi_d_kargs>& gemm_descs)
|
||||
{
|
||||
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg<2>);
|
||||
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg<DsDataType::size()>);
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
|
||||
@@ -86,9 +86,43 @@ float invoke_gemm(int n_warmup,
|
||||
}
|
||||
else
|
||||
{
|
||||
(void)group_count;
|
||||
// not supported yet
|
||||
throw std::runtime_error("Persistent grouped gemm multiple-d is not supported yet");
|
||||
std::vector<ck_tile::GemmTransKernelArg<DsDataType::size()>> kargs;
|
||||
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
|
||||
const bool splitk = args[0].k_batch > 1;
|
||||
for(const auto& arg : args)
|
||||
{
|
||||
kargs.emplace_back(ck_tile::UniversalGemmKernelArgs<1, 1, 2>{{arg.a_ptr},
|
||||
{arg.b_ptr},
|
||||
arg.ds_ptr,
|
||||
arg.e_ptr,
|
||||
arg.M,
|
||||
arg.N,
|
||||
arg.K,
|
||||
{arg.stride_A},
|
||||
{arg.stride_B},
|
||||
arg.stride_Ds,
|
||||
arg.stride_E,
|
||||
arg.k_batch});
|
||||
}
|
||||
const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat};
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(
|
||||
kargs_ptr,
|
||||
kargs.data(),
|
||||
kargs.size() * sizeof(ck_tile::GemmTransKernelArg<DsDataType::size()>),
|
||||
hipMemcpyHostToDevice,
|
||||
stream.stream_id_));
|
||||
ave_time =
|
||||
grouped_gemm_multi_d_tileloop<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise>(stream, group_count, kargs_ptr, splitk);
|
||||
}
|
||||
return ave_time;
|
||||
}
|
||||
@@ -322,12 +356,6 @@ int run_grouped_gemm_multi_d_example_with_layouts(int argc,
|
||||
b_k_n_tensors[i],
|
||||
{d0_m_n_tensors[i], d1_m_n_tensors[i]},
|
||||
e_m_n_host_refs[i]);
|
||||
std::cout << "e_m_n_host_refs[i]: " << std::endl;
|
||||
e_m_n_host_refs[i].print_first_n(std::cout, 10);
|
||||
std::cout << std::endl;
|
||||
std::cout << "e_m_n_tensors[i]: " << std::endl;
|
||||
e_m_n_tensors[i].print_first_n(std::cout, 10);
|
||||
std::cout << std::endl;
|
||||
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(e_m_n_host_refs[i].mData.begin(), e_m_n_host_refs[i].mData.end());
|
||||
|
||||
@@ -324,10 +324,18 @@ struct GroupedGemmKernel
|
||||
}
|
||||
else // SingleSmemBuffer
|
||||
{
|
||||
|
||||
if constexpr(UsePersistentKernel)
|
||||
{
|
||||
RunGemmWithPipelineSelection(
|
||||
a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
|
||||
RunGemmWithPipelineSelection(a_ptr,
|
||||
b_ptr,
|
||||
kargs.ds_ptr,
|
||||
c_ptr,
|
||||
smem_ptr_0,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
else // Non-persistent kernel
|
||||
{
|
||||
@@ -365,6 +373,7 @@ struct GroupedGemmKernel
|
||||
CK_TILE_DEVICE static void
|
||||
RunGemmWithPipelineSelection(const ADataType* a_ptr,
|
||||
const BDataType* b_ptr,
|
||||
const std::array<const void*, NumDTensor_>& ds_ptr,
|
||||
CDataType* c_ptr,
|
||||
void* smem_ptr_0,
|
||||
const UniversalGemmKernelArgs<1, 1, NumDTensor_>& kargs,
|
||||
@@ -375,7 +384,7 @@ struct GroupedGemmKernel
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
{a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset.splitted_k);
|
||||
{a_ptr}, {b_ptr}, ds_ptr, c_ptr, kargs, splitk_batch_offset.splitted_k);
|
||||
|
||||
const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows =
|
||||
|
||||
Reference in New Issue
Block a user