mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
* Add stride validation to prevent segfault in blockscale GEMM
* run clang-format
* Update profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp
Co-authored-by: rahjain-amd <Rahul.Jain@amd.com>
* added stride length checking to more gemm examples in ckprofiler
* ran clang format
* added validation header and implement in core gemm operations
* remove ck_tile transpose and gemm stages from CI (#2646)
* update CK build instruction step 4 (#2563)
Co-authored-by: Aviral Goel <aviral.goel@amd.com>
* Fixes to "General 2D Reduction Kernel" (#2535) (#2656)
* fix reduce2d
- revert the combine_partial_results() changes
- remove auto from function def
* clang-format
* enable aiter test_mha in daily CI (#2659)
* feat(copy_kernel): add basic copy kernel example with beginner friendly documentation (#2582)
* feat(copy_kernel): add basic copy kernel example with documentation
* docs(CHANGELOG): Updated changelog
* chore: performed clang format
* Update example/ck_tile/39_copy/copy_basic.cpp
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Update example/ck_tile/39_copy/README.md
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Update example/ck_tile/39_copy/README.md
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Update example/ck_tile/39_copy/README.md
Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>
* Update example/ck_tile/39_copy/README.md
Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>
* Update example/ck_tile/39_copy/README.md
Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>
* fix(terminology): follow amd terms
* extract elementwise copy to a new kernel
* fix(copy_kernel): bug in verification
* add comments about vgpr usage
* lint and nits
* add notes and comments
* print hostTensor via stream
* print hostTensor via stream
---------
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>
* [CK_TILE] FMHA BWD Optimization For GFX950 (#2628)
* simplify fmha_bwd_kernel MakeKargs & dq_dram_window
* simply duplicate
* trload pipeline
* Try two-stage
* add prefetch
* optimize & iglp
* Fix num_byte calculations to use nhead_k for K & V size (#2653)
Simple fix to calculate the number of bytes correctly for what's reported in the output. I was getting 6200 GB/s, which exceeds the speed-of-light (theoretical peak) memory bandwidth of MI300.
Before:
```
./bin/tile_example_fmha_fwd -prec=bf16 -b=2 -s=1 -s_k=32768 -h=32 -h_k=8 -d=128 -page_block_size=128 -num_splits=8 -iperm=0 -operm=0 -v=0 -kname=1
[bf16|batch|bshd] b:2, h:32/8, s:1/32768, d:128/128, scale_s:0.0883883, bias:n, p_drop:0, lse:0, squant:0, mask:n, v:r, num_splits:8, page_block_size:128, fmha_fwd_splitkv_d128_bf16_batch_b16x64x64x128x64x128_r1x4x1_r1x4x1_w16x16x16_w16x16x16_qr_nwarp_sshuffle_vr_ps_nlogits_nbias_nmask_lse_nsquant_pagedkv, fmha_fwd_splitkv_combine_d128_bf16_batch_b32_unused_ps_nlse_nsquant, 0.173 ms, 6.20 TFlops, 6202.95 GB/s
```
After:
```
./bin/tile_example_fmha_fwd -prec=bf16 -b=2 -s=1 -s_k=32768 -h=32 -h_k=8 -d=128 -page_block_size=128 -num_splits=8 -iperm=0 -operm=0 -v=0 -kname=1
[bf16|batch|bshd] b:2, h:32/8, s:1/32768, d:128/128, scale_s:0.0883883, bias:n, p_drop:0, lse:0, squant:0, mask:n, v:r, num_splits:8, page_block_size:128, fmha_fwd_splitkv_d128_bf16_batch_b16x64x64x128x64x128_r1x4x1_r1x4x1_w16x16x16_w16x16x16_qr_nwarp_sshuffle_vr_ps_nlogits_nbias_nmask_lse_nsquant_pagedkv, fmha_fwd_splitkv_combine_d128_bf16_batch_b32_unused_ps_nlse_nsquant, 0.163 ms, 6.58 TFlops, 1644.53 GB/s
```
* [CK_TILE] FMHA BWD Decode Pipeline (#2643)
* Fix distr
* Duplicate block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr
* decode 16x16 o2
* fix (#2668)
* Optimize fmha fwd decode & prefill for gfx950 (#2641)
* Fix for fwd/bwd kernel build filter
* fix bwd code
* save an example for __bf16 type
* temp save, waiting for debug
* tempsave, fmha_decode
* temp save, change all instance to 1wave
* fix async copytest bug
* Add block_sync_lds_direct_load utility
* fix the s_waitcnt_imm calculation
* Improve s_waitcnt_imm calculation
* fix vmcnt shift
* add input validation and bug fix
* remove unnecessary output
* move test_copy into test
* temp save
* tempsave
* compile pass
* tempsave, trload+asyncload done
* tempsave. asynccopy+trload sanity checked
* remove unnecessary features
* fix the lds alignment caused performance regression
* enable prefill overload operator().
* remove all lds bankconflict with xor layouts
* enable larger tile size; upgrade xor pattern
* upgrade prefill pipeline; simple iglp; consistent data produce and consume order
* small refactor
* Load Q through lds, implement xor;
* add vmcnt guard before load ktile
* Add v_permlaneb32 for block_reduce. Disable it as it will cause un-coexecutable packed math in FA
* Add XOR fold strategy for hdim<128, but perf dropped; disable it by default; wait further perf debug
* add __restrict__ to tr load
* merge fa_decode pipeline into fmha_fwd api
* remove unnecessary files; rename some files
* Remove unnecessary changes
* bug fix, clang format;
* remove non-necessary change
* fix clangformat with 18.1.3
* fix bugs
* fix bug
* fix bug on non-gfx950
* fix bugs in gemm
* fix bug in pki4
* tempsave, update the blocksync functions
* change the warp setting for hdim32 fmha fwd
* clang format
* fix conflict. disable all v-col instance for fmha fwd
* Fix the bug
* clang format
---------
Co-authored-by: Max Podkorytov <4273004+tenpercent@users.noreply.github.com>
* Revert "Optimize fmha fwd decode & prefill for gfx950 (#2641)" (#2670)
This reverts commit b7322a521a.
* added batch stride checking to batched gemm ops in profiler
* removed batch stride validation
* removed batched stride validation again
* Update include/ck/library/utility/profiler_validation_common.hpp
Co-authored-by: rahjain-amd <Rahul.Jain@amd.com>
* refactor function names
* added gemm stride checking to more profiler gemm operations
* run clang format
* add stride checking to 01 gemm example
* rename from profiler to validation common, used for examples and profiler
* build of ckProfiler success
* update file headers
---------
Co-authored-by: rahjain-amd <Rahul.Jain@amd.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: geozhai <44495440+geozhai@users.noreply.github.com>
Co-authored-by: Aviral Goel <aviral.goel@amd.com>
Co-authored-by: Yashvardhan Agarwal <yashagar@amd.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>
Co-authored-by: Yi DING <yi.ding@amd.com>
Co-authored-by: Cameron Shinn <camerontshinn@gmail.com>
Co-authored-by: Mateusz Ozga <110818320+mozga-amd@users.noreply.github.com>
Co-authored-by: Haocong WANG <haocwang@amd.com>
Co-authored-by: Max Podkorytov <4273004+tenpercent@users.noreply.github.com>
Co-authored-by: asleepzzz <hanwen.chang@amd.com>
205 lines
7.4 KiB
C++
205 lines
7.4 KiB
C++
// SPDX-License-Identifier: MIT
|
|
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
#pragma once
|
|
|
|
template <typename ProblemType>
|
|
bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
|
{
|
|
#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
|
|
static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
|
|
#endif
|
|
|
|
using namespace ck::literals;
|
|
|
|
auto M = problem_size.M;
|
|
auto N = problem_size.N;
|
|
auto K = problem_size.K;
|
|
auto StrideA = problem_size.StrideA;
|
|
auto StrideB = problem_size.StrideB;
|
|
auto StrideC = problem_size.StrideC;
|
|
auto KBatch = problem_size.KBatch;
|
|
|
|
auto f_host_tensor_descriptor =
|
|
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
|
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
|
{
|
|
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
|
}
|
|
else
|
|
{
|
|
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
|
}
|
|
};
|
|
|
|
auto f_get_default_stride =
|
|
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
|
|
if(stride == -1 || stride == 0)
|
|
{
|
|
// give a chance if stride is -1, return a default packed stride
|
|
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
|
{
|
|
return static_cast<std::size_t>(col);
|
|
}
|
|
else
|
|
{
|
|
return static_cast<std::size_t>(row);
|
|
}
|
|
}
|
|
else
|
|
return static_cast<std::size_t>(stride);
|
|
};
|
|
|
|
StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
|
|
StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
|
|
StrideC = f_get_default_stride(M, N, StrideC, CLayout{});
|
|
|
|
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
|
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
|
|
|
switch(config.init_method)
|
|
{
|
|
case 0:
|
|
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
|
|
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
|
|
break;
|
|
case 1:
|
|
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
|
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
|
|
break;
|
|
case 2:
|
|
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
|
|
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
|
|
break;
|
|
case 3:
|
|
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
|
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
|
|
break;
|
|
default:
|
|
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
|
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
|
}
|
|
|
|
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
|
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
|
|
|
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
|
|
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
|
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
|
|
|
|
#ifdef BUILD_INT4_EXAMPLE
|
|
DeviceMem a_m_k_device_buf(sizeof(KernelADataType) * a_m_k.mDesc.GetElementSpaceSize());
|
|
DeviceMem b_k_n_device_buf(sizeof(KernelBDataType) * b_k_n.mDesc.GetElementSpaceSize());
|
|
DeviceMem c_m_n_device_buf(sizeof(KernelCDataType) *
|
|
c_m_n_device_result.mDesc.GetElementSpaceSize());
|
|
|
|
const Tensor<KernelADataType> a_m_k_converted(a_m_k);
|
|
const Tensor<KernelBDataType> b_k_n_converted(b_k_n);
|
|
|
|
a_m_k_device_buf.ToDevice(a_m_k_converted.mData.data());
|
|
b_k_n_device_buf.ToDevice(b_k_n_converted.mData.data());
|
|
#else
|
|
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
|
|
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
|
|
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
|
|
|
|
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
|
|
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
|
|
#endif
|
|
DeviceMem workspace;
|
|
|
|
auto a_element_op = AElementOp{};
|
|
auto b_element_op = BElementOp{};
|
|
auto c_element_op = CElementOp{};
|
|
|
|
// do GEMM
|
|
auto gemm = DeviceGemmV2Instance{};
|
|
auto invoker = gemm.MakeInvoker();
|
|
float ave_time = 0;
|
|
|
|
auto argument = gemm.MakeArgument(
|
|
#ifdef BUILD_INT4_EXAMPLE
|
|
static_cast<KernelADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
|
static_cast<KernelBDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
|
static_cast<KernelCDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
|
#else
|
|
static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
|
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
|
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
|
#endif
|
|
M,
|
|
N,
|
|
K,
|
|
StrideA,
|
|
StrideB,
|
|
StrideC,
|
|
KBatch,
|
|
a_element_op,
|
|
b_element_op,
|
|
c_element_op);
|
|
|
|
if(!gemm.IsSupportedArgument(argument))
|
|
{
|
|
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
|
|
|
|
return true;
|
|
}
|
|
|
|
bool pass = true;
|
|
if((config.do_verification == 1) || (config.do_verification == 3))
|
|
{
|
|
auto ref_gemm = ReferenceGemmInstance{};
|
|
auto ref_invoker = ref_gemm.MakeInvoker();
|
|
|
|
auto ref_argument = ref_gemm.MakeArgument(
|
|
a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{});
|
|
|
|
ref_invoker.Run(ref_argument);
|
|
|
|
ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 1});
|
|
#ifdef BUILD_INT4_EXAMPLE
|
|
Tensor<CDataType> c_m_n_device_result_converted(c_m_n_host_result.mDesc);
|
|
|
|
c_m_n_device_buf.FromDevice(c_m_n_device_result_converted.mData.data());
|
|
|
|
c_m_n_device_result = c_m_n_device_result_converted.CopyAsType<CDataType>();
|
|
|
|
return ck::utils::check_err(c_m_n_device_result_converted, c_m_n_host_result);
|
|
#else
|
|
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
|
|
|
pass &= ck::utils::check_err(c_m_n_device_result,
|
|
c_m_n_host_result,
|
|
"Error: Incorrect results!",
|
|
get_rtol<CDataType>(),
|
|
get_atol<CDataType>());
|
|
#endif
|
|
}
|
|
|
|
if(config.time_kernel)
|
|
{
|
|
ave_time =
|
|
invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 50, 100, true, 4});
|
|
|
|
std::size_t flop = 2_uz * M * N * K;
|
|
std::size_t num_btype =
|
|
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
|
|
|
|
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
|
|
|
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
|
|
|
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
|
<< " GB/s, " << gemm.GetTypeString() << std::endl;
|
|
}
|
|
return pass;
|
|
}
|
|
|
|
bool run_gemm_splitk_example(int argc, char* argv[])
|
|
{
|
|
ProblemSizeSplitK problem_size;
|
|
ExecutionConfig config;
|
|
|
|
return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config);
|
|
}
|