mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
upgrade from clang-format-12 to clang-format-18 (#2568)
* upgrade to clang-format-18
* update to clang-format-18 in pre-commit-config
[ROCm/composable_kernel commit: 504b101da3]
This commit is contained in:
@@ -3,7 +3,7 @@ repos:
|
||||
hooks:
|
||||
- id: clang-format
|
||||
name: clang-format
|
||||
entry: clang-format-12 -i --style=file
|
||||
entry: clang-format-18 -i --style=file
|
||||
language: system
|
||||
types_or: [c++, inc]
|
||||
- id: copyright-year-checker
|
||||
|
||||
@@ -62,6 +62,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
|
||||
libzstd-dev \
|
||||
openssh-server \
|
||||
clang-format-12 \
|
||||
clang-format-18 \
|
||||
kmod && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
|
||||
4
Jenkinsfile
vendored
4
Jenkinsfile
vendored
@@ -994,7 +994,7 @@ pipeline {
|
||||
-o -iname \'*.cpp.in\' \
|
||||
-o -iname \'*.cl\' \
|
||||
| grep -v 'build/' \
|
||||
| xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-12 -style=file {} | diff - {}\' && \
|
||||
| xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-18 -style=file {} | diff - {}\' && \
|
||||
/cppcheck/build/bin/cppcheck ../* -v -j \$(nproc) -I ../include -I ../profiler/include -I ../library/include \
|
||||
-D CK_ENABLE_FP64 -D CK_ENABLE_FP32 -D CK_ENABLE_FP16 -D CK_ENABLE_FP8 -D CK_ENABLE_BF16 -D CK_ENABLE_BF8 -D CK_ENABLE_INT8 \
|
||||
-D __gfx908__ -D __gfx90a__ -D __gfx942__ -D __gfx1030__ -D __gfx1100__ -D __gfx1101__ -D __gfx1102__ \
|
||||
@@ -1023,7 +1023,7 @@ pipeline {
|
||||
-o -iname \'*.cpp.in\' \
|
||||
-o -iname \'*.cl\' \
|
||||
| grep -v 'build/' \
|
||||
| xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-12 -style=file {} | diff - {}\'"
|
||||
| xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-18 -style=file {} | diff - {}\'"
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_args:setup_args, setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot:true)
|
||||
|
||||
@@ -107,14 +107,14 @@ int execute_conv_fwd()
|
||||
auto& op_ptr = op_ptrs[i];
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(),
|
||||
wei.GetDeviceBuffer(),
|
||||
{},
|
||||
{},
|
||||
out.GetDeviceBuffer(),
|
||||
in_lengths,
|
||||
in_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
out_lengths,
|
||||
out_strides,
|
||||
filter_strides,
|
||||
|
||||
@@ -130,14 +130,14 @@ int main()
|
||||
auto& op_ptr = op_ptrs[i];
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(),
|
||||
wei.GetDeviceBuffer(),
|
||||
{},
|
||||
{},
|
||||
in.GetDeviceBuffer(),
|
||||
out_lengths,
|
||||
out_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
in_lengths,
|
||||
in_strides,
|
||||
filter_strides,
|
||||
|
||||
@@ -105,14 +105,14 @@ int main()
|
||||
auto& op_ptr = op_ptrs[i];
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(),
|
||||
wei.GetDeviceBuffer(),
|
||||
{},
|
||||
{},
|
||||
in.GetDeviceBuffer(),
|
||||
out_lengths,
|
||||
out_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
in_lengths,
|
||||
in_strides,
|
||||
filter_strides,
|
||||
|
||||
@@ -109,14 +109,14 @@ int main()
|
||||
auto& op_ptr = op_ptrs[i];
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(),
|
||||
wei.GetDeviceBuffer(),
|
||||
{},
|
||||
{},
|
||||
in.GetDeviceBuffer(),
|
||||
out_lengths,
|
||||
out_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
in_lengths,
|
||||
in_strides,
|
||||
filter_strides,
|
||||
|
||||
@@ -111,14 +111,14 @@ int main()
|
||||
auto& op_ptr = op_ptrs[i];
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(),
|
||||
wei.GetDeviceBuffer(),
|
||||
{},
|
||||
{},
|
||||
in.GetDeviceBuffer(),
|
||||
out_lengths,
|
||||
out_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
in_lengths,
|
||||
in_strides,
|
||||
filter_strides,
|
||||
|
||||
@@ -59,7 +59,7 @@ int main()
|
||||
SimpleDeviceMem y_dev_buf(sizeof(YDataType) * mn_size);
|
||||
|
||||
std::array<const void*, 2> ab_input = {a_dev_buf.GetDeviceBuffer(),
|
||||
b_dev_buf.GetDeviceBuffer()};
|
||||
b_dev_buf.GetDeviceBuffer()};
|
||||
std::vector<ck::index_t> abStride = {Stride, 1};
|
||||
std::array<std::vector<ck::index_t>, 2> abStrides = {abStride, abStride};
|
||||
|
||||
|
||||
@@ -68,15 +68,15 @@ int main(int argc, char* argv[])
|
||||
SimpleDeviceMem out(sizeof(OutDataType) * num_out_elements);
|
||||
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceReduce<InDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
Rank,
|
||||
NumReduceDim,
|
||||
ReduceAdd,
|
||||
PassThrough,
|
||||
UnaryDivide,
|
||||
PropagateNan,
|
||||
OutputIndex>;
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
Rank,
|
||||
NumReduceDim,
|
||||
ReduceAdd,
|
||||
PassThrough,
|
||||
UnaryDivide,
|
||||
PropagateNan,
|
||||
OutputIndex>;
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
|
||||
@@ -117,14 +117,14 @@ int execute_conv_bwd_data_bilinear()
|
||||
auto& op_ptr = op_ptrs[i];
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(),
|
||||
wei.GetDeviceBuffer(),
|
||||
{in.GetDeviceBuffer()},
|
||||
{in.GetDeviceBuffer()},
|
||||
in.GetDeviceBuffer(),
|
||||
out_lengths,
|
||||
out_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
{in_lengths},
|
||||
{in_strides},
|
||||
{in_lengths},
|
||||
{in_strides},
|
||||
in_lengths,
|
||||
in_strides,
|
||||
filter_strides,
|
||||
|
||||
@@ -116,14 +116,14 @@ int execute_conv_bwd_data_scale()
|
||||
auto& op_ptr = op_ptrs[i];
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(),
|
||||
wei.GetDeviceBuffer(),
|
||||
{},
|
||||
{},
|
||||
in.GetDeviceBuffer(),
|
||||
out_lengths,
|
||||
out_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
in_lengths,
|
||||
in_strides,
|
||||
filter_strides,
|
||||
|
||||
@@ -121,14 +121,14 @@ int execute_conv_fwd_bilinear()
|
||||
auto& op_ptr = op_ptrs[i];
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(),
|
||||
wei.GetDeviceBuffer(),
|
||||
{out.GetDeviceBuffer()},
|
||||
{out.GetDeviceBuffer()},
|
||||
out.GetDeviceBuffer(),
|
||||
in_lengths,
|
||||
in_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
{out_lengths},
|
||||
{out_strides},
|
||||
{out_lengths},
|
||||
{out_strides},
|
||||
out_lengths,
|
||||
out_strides,
|
||||
filter_strides,
|
||||
|
||||
@@ -222,13 +222,13 @@ bool run_grouped_conv_fwd_convscale_reduce(
|
||||
ck::tensor_operation::element_wise::Scale{scale_wei},
|
||||
{}};
|
||||
auto conv_ok = ConvolutionScale<InDataType,
|
||||
WeiDataType,
|
||||
ConvOutDataType,
|
||||
ConvElementOp,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
NumDimSpatial>(in,
|
||||
WeiDataType,
|
||||
ConvOutDataType,
|
||||
ConvElementOp,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
NumDimSpatial>(in,
|
||||
wei,
|
||||
conv_out,
|
||||
elementwise_op,
|
||||
@@ -717,15 +717,15 @@ bool TensorFullReduction(SimpleDeviceMem& tensor,
|
||||
{
|
||||
std::cout << "\nReduction of spatial dimensions:" << std::endl;
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceReduce<OutDataType,
|
||||
OutDataType,
|
||||
OutDataType,
|
||||
NumDimSpatial,
|
||||
NumDimSpatial,
|
||||
ReduceOperation,
|
||||
PassThrough,
|
||||
AccElementwiseOperation,
|
||||
true, // PropagateNan
|
||||
false>; // OutputIndex
|
||||
OutDataType,
|
||||
OutDataType,
|
||||
NumDimSpatial,
|
||||
NumDimSpatial,
|
||||
ReduceOperation,
|
||||
PassThrough,
|
||||
AccElementwiseOperation,
|
||||
true, // PropagateNan
|
||||
false>; // OutputIndex
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
|
||||
@@ -120,14 +120,14 @@ int execute_conv_fwd_scale()
|
||||
auto& op_ptr = op_ptrs[i];
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(),
|
||||
wei.GetDeviceBuffer(),
|
||||
{},
|
||||
{},
|
||||
out.GetDeviceBuffer(),
|
||||
in_lengths,
|
||||
in_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
out_lengths,
|
||||
out_strides,
|
||||
filter_strides,
|
||||
|
||||
@@ -129,8 +129,8 @@ int execute_conv_fwd_scaleadd_ab()
|
||||
in_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
out_lengths,
|
||||
out_strides,
|
||||
filter_strides,
|
||||
|
||||
@@ -132,9 +132,9 @@ void PerformImageToColumnPad0(const ck::index_t G,
|
||||
ck::wrapper::size<0>(tile_shape));
|
||||
|
||||
const auto kernel = DeviceImageToColumnPad0<decltype(input_tensor_global),
|
||||
decltype(output_tensor_global),
|
||||
decltype(tile_shape),
|
||||
decltype(thread_layout)>;
|
||||
decltype(output_tensor_global),
|
||||
decltype(tile_shape),
|
||||
decltype(thread_layout)>;
|
||||
const float avg_time = launch_and_time_kernel(StreamConfig{nullptr, true},
|
||||
kernel,
|
||||
dim3(grid_size_x, grid_size_y, 1),
|
||||
|
||||
@@ -91,8 +91,9 @@ inline auto Transform(const Range& r, F f) -> std::vector<decltype(f(*r.begin())
|
||||
}
|
||||
|
||||
template <class Range1, class Range2, class F>
|
||||
inline auto Transform(const Range1& r1, const Range2& r2, F f)
|
||||
-> std::vector<decltype(f(*r1.begin(), *r2.begin()))>
|
||||
inline auto Transform(const Range1& r1,
|
||||
const Range2& r2,
|
||||
F f) -> std::vector<decltype(f(*r1.begin(), *r2.begin()))>
|
||||
{
|
||||
std::vector<decltype(f(*r1.begin(), *r2.begin()))> result;
|
||||
assert(std::distance(r1.begin(), r1.end()) == std::distance(r2.begin(), r2.end()));
|
||||
|
||||
@@ -142,12 +142,11 @@ std::vector<Operation_Conv_Fwd_Xdl_Cshuffle> Operation_Conv_Fwd_Xdl_Cshuffle::Cr
|
||||
x.A = TensorDesc{prob.ADataType, prob.ALayout};
|
||||
x.B = TensorDesc{prob.BDataType, prob.BLayout};
|
||||
x.E = TensorDesc{prob.EDataType, prob.ELayout};
|
||||
x.Ds = Transform(prob.DsLayout, prob.DsDataType, [](auto lo, auto dt) {
|
||||
return TensorDesc{dt, lo};
|
||||
});
|
||||
x.a_elem_op = prob.AElementOp;
|
||||
x.b_elem_op = prob.BElementOp;
|
||||
x.cde_elem_op = prob.CDEElementOp;
|
||||
x.Ds = Transform(
|
||||
prob.DsLayout, prob.DsDataType, [](auto lo, auto dt) { return TensorDesc{dt, lo}; });
|
||||
x.a_elem_op = prob.AElementOp;
|
||||
x.b_elem_op = prob.BElementOp;
|
||||
x.cde_elem_op = prob.CDEElementOp;
|
||||
x.update_prologue(prologue);
|
||||
x.update_epilogue(epilogue);
|
||||
result.push_back(x);
|
||||
|
||||
@@ -55,12 +55,12 @@ TEST_CASE(test_problem_kernel)
|
||||
std::cout << "Testing solution " << std::to_string(i + 1) << std::endl;
|
||||
auto&& solution = solutions[i];
|
||||
auto src = ck::host::InterpolateString(gemm_compile_check,
|
||||
{{"include", prob.GetIncludeHeader()},
|
||||
{"template", solution.ToTemplateString()},
|
||||
{"m", std::to_string(prob.M)},
|
||||
{"n", std::to_string(prob.N)},
|
||||
{"k", std::to_string(prob.K)},
|
||||
{"o", std::to_string(prob.O)}});
|
||||
{{"include", prob.GetIncludeHeader()},
|
||||
{"template", solution.ToTemplateString()},
|
||||
{"m", std::to_string(prob.M)},
|
||||
{"n", std::to_string(prob.N)},
|
||||
{"k", std::to_string(prob.K)},
|
||||
{"o", std::to_string(prob.O)}});
|
||||
auto srcs = get_headers_for_test();
|
||||
srcs.push_back({"main.cpp", src});
|
||||
rtc::compile_options options;
|
||||
|
||||
@@ -60,11 +60,11 @@ TEST_CASE(test_problem_kernel)
|
||||
std::cout << "Testing solution " << std::to_string(i + 1) << std::endl;
|
||||
auto&& solution = solutions[i];
|
||||
auto src = ck::host::InterpolateString(gemm_compile_check,
|
||||
{{"include", prob.GetIncludeHeader()},
|
||||
{"template", solution.ToTemplateString()},
|
||||
{"m", std::to_string(prob.M)},
|
||||
{"n", std::to_string(prob.N)},
|
||||
{"k", std::to_string(prob.K)}});
|
||||
{{"include", prob.GetIncludeHeader()},
|
||||
{"template", solution.ToTemplateString()},
|
||||
{"m", std::to_string(prob.M)},
|
||||
{"n", std::to_string(prob.N)},
|
||||
{"k", std::to_string(prob.K)}});
|
||||
auto srcs = get_headers_for_test();
|
||||
srcs.push_back({"main.cpp", src});
|
||||
rtc::compile_options options;
|
||||
|
||||
@@ -16,7 +16,7 @@ struct tmp_dir
|
||||
|
||||
void execute(const std::string& cmd) const;
|
||||
|
||||
tmp_dir(tmp_dir const&) = delete;
|
||||
tmp_dir(tmp_dir const&) = delete;
|
||||
tmp_dir& operator=(tmp_dir const&) = delete;
|
||||
|
||||
~tmp_dir();
|
||||
|
||||
@@ -29,4 +29,4 @@ The following prerequisites are required to build and install Composable Kernel:
|
||||
* zlib1g-dev
|
||||
* libzstd-dev
|
||||
* openssh-server
|
||||
* clang-format-12
|
||||
* clang-format-18
|
||||
|
||||
@@ -31,15 +31,10 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl
|
||||
#else
|
||||
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 4, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>;
|
||||
#endif
|
||||
// clang-format on
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm<ALayout,
|
||||
BLayout,
|
||||
|
||||
@@ -56,10 +56,10 @@ using CDataType = float;
|
||||
using AccDataType = float;
|
||||
|
||||
#endif
|
||||
// clang-format on
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, float, AElementOp, BElementOp, CElementOp>;
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, float, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
template <typename DataType>
|
||||
std::ostream& show_2d_matrix(std::ostream& os, Tensor<DataType>& matrix)
|
||||
|
||||
@@ -117,7 +117,7 @@ int reduce_blockwise_impl(bool do_verification,
|
||||
using InOutDataTypeInDevice = typename std::
|
||||
conditional<std::is_same<InOutDataType, int4_t>::value, int8_t, InOutDataType>::type;
|
||||
#else
|
||||
using InOutDataTypeInDevice = InOutDataType;
|
||||
using InOutDataTypeInDevice = InOutDataType;
|
||||
#endif
|
||||
|
||||
using DeviceReduceInstance =
|
||||
|
||||
@@ -175,15 +175,15 @@ auto run_gemm_reduce_max_xdl(ck::index_t M,
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
auto argument = device_op.MakeArgument(a_device_buf.GetDeviceBuffer(),
|
||||
b_device_buf.GetDeviceBuffer(),
|
||||
{},
|
||||
{},
|
||||
e_device_buf.GetDeviceBuffer(),
|
||||
{r0_device_buf.GetDeviceBuffer()},
|
||||
{r0_device_buf.GetDeviceBuffer()},
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
{},
|
||||
{},
|
||||
StrideE,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
|
||||
@@ -207,7 +207,7 @@ int main(int argc, char* argv[])
|
||||
auto argument = batched_gemm.MakeArgument(a_device_buf.GetDeviceBuffer(),
|
||||
b_device_buf.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
{},
|
||||
{},
|
||||
c_device_buf.GetDeviceBuffer(),
|
||||
p_reduces,
|
||||
M,
|
||||
@@ -216,9 +216,9 @@ int main(int argc, char* argv[])
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
{},
|
||||
{},
|
||||
gemm_element_ops,
|
||||
{},
|
||||
{},
|
||||
reduce_in_element_ops,
|
||||
reduce_out_element_ops,
|
||||
BatchCount);
|
||||
|
||||
@@ -44,9 +44,9 @@ int run_layernorm2d_fwd_example()
|
||||
{0, 1},
|
||||
std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
|
||||
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
|
||||
save_mean.mDesc.GetStrides().end()},
|
||||
save_mean.mDesc.GetStrides().end()},
|
||||
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
|
||||
save_mean.mDesc.GetStrides().end()},
|
||||
save_mean.mDesc.GetStrides().end()},
|
||||
{1},
|
||||
1e-4,
|
||||
x_dev.GetDeviceBuffer(),
|
||||
|
||||
@@ -126,10 +126,10 @@ int run(int argc, char* argv[])
|
||||
|
||||
if(i < 4)
|
||||
{
|
||||
std::cout << "a_gs_ms_ks[" << i << "]: " << a_gs_ms_ks.mDesc << ", "
|
||||
<< "b0_gs_ns_ks[" << i << "]: " << b0_gs_ns_ks.mDesc << ", "
|
||||
<< "b1_gs_os_ns[" << i << "]: " << b1_gs_os_ns.mDesc << ", "
|
||||
<< "c_gs_ms_os[" << i << "]: " << c_gs_ms_os_device_result.mDesc << std::endl;
|
||||
std::cout << "a_gs_ms_ks[" << i << "]: " << a_gs_ms_ks.mDesc << ", " << "b0_gs_ns_ks["
|
||||
<< i << "]: " << b0_gs_ns_ks.mDesc << ", " << "b1_gs_os_ns[" << i
|
||||
<< "]: " << b1_gs_os_ns.mDesc << ", " << "c_gs_ms_os[" << i
|
||||
<< "]: " << c_gs_ms_os_device_result.mDesc << std::endl;
|
||||
}
|
||||
|
||||
switch(init_method)
|
||||
|
||||
@@ -129,11 +129,11 @@ int main()
|
||||
auto argument_ptr = device_instance.MakeArgumentPointer(
|
||||
out_dev.GetDeviceBuffer(),
|
||||
{ck::type_convert<EmbType*>(emb_a_dev.GetDeviceBuffer()),
|
||||
ck::type_convert<EmbType*>(emb_b_dev.GetDeviceBuffer()),
|
||||
ck::type_convert<EmbType*>(emb_c_dev.GetDeviceBuffer())},
|
||||
ck::type_convert<EmbType*>(emb_b_dev.GetDeviceBuffer()),
|
||||
ck::type_convert<EmbType*>(emb_c_dev.GetDeviceBuffer())},
|
||||
{ck::type_convert<IndexType*>(index_a_dev.GetDeviceBuffer()),
|
||||
ck::type_convert<IndexType*>(index_b_dev.GetDeviceBuffer()),
|
||||
ck::type_convert<IndexType*>(index_c_dev.GetDeviceBuffer())},
|
||||
ck::type_convert<IndexType*>(index_b_dev.GetDeviceBuffer()),
|
||||
ck::type_convert<IndexType*>(index_c_dev.GetDeviceBuffer())},
|
||||
gamma_dev.GetDeviceBuffer(),
|
||||
beta_dev.GetDeviceBuffer(),
|
||||
current_dim,
|
||||
|
||||
@@ -249,8 +249,8 @@ inline auto to_array(Range& range) noexcept
|
||||
}
|
||||
|
||||
template <typename Axes>
|
||||
inline auto is_valid_axes(const Axes& axes)
|
||||
-> std::enable_if_t<detail::is_random_access_range_v<Axes>, bool>
|
||||
inline auto
|
||||
is_valid_axes(const Axes& axes) -> std::enable_if_t<detail::is_random_access_range_v<Axes>, bool>
|
||||
{
|
||||
using std::empty;
|
||||
if(empty(axes))
|
||||
@@ -357,10 +357,11 @@ auto extend_axes(const Problem::Axes& axes)
|
||||
}
|
||||
|
||||
template <typename Shape, typename Indices>
|
||||
auto advance_indices(const Shape& shape, Indices& indices) -> std::enable_if_t<
|
||||
detail::is_bidirectional_range_v<Shape> && detail::is_sized_range_v<Shape> &&
|
||||
detail::is_bidirectional_range_v<Indices> && detail::is_sized_range_v<Indices>,
|
||||
bool>
|
||||
auto advance_indices(const Shape& shape, Indices& indices)
|
||||
-> std::enable_if_t<
|
||||
detail::is_bidirectional_range_v<Shape> && detail::is_sized_range_v<Shape> &&
|
||||
detail::is_bidirectional_range_v<Indices> && detail::is_sized_range_v<Indices>,
|
||||
bool>
|
||||
{
|
||||
using std::size;
|
||||
if(!(is_valid_shape(shape) && is_valid_indices(shape, indices) && size(shape) == size(indices)))
|
||||
|
||||
@@ -65,9 +65,9 @@ int run_groupnorm_fwd_example(int argc, char* argv[])
|
||||
{0, 0, 0, C, 1},
|
||||
std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
|
||||
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
|
||||
save_mean.mDesc.GetStrides().end()},
|
||||
save_mean.mDesc.GetStrides().end()},
|
||||
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
|
||||
save_mean.mDesc.GetStrides().end()},
|
||||
save_mean.mDesc.GetStrides().end()},
|
||||
{1, 2, 4}, // reduction dimension: [H, W, C]
|
||||
1e-6,
|
||||
x_dev.GetDeviceBuffer(),
|
||||
|
||||
@@ -152,7 +152,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
std::array<const void*, 1> inputs = {input_dev_buf.GetDeviceBuffer()};
|
||||
std::array<void*, 2> outputs = {output_scaled_casted_transposed_dev_buf.GetDeviceBuffer(),
|
||||
output_scaled_casted_dev_buf.GetDeviceBuffer()};
|
||||
output_scaled_casted_dev_buf.GetDeviceBuffer()};
|
||||
|
||||
std::cout << "Input: " << input.mDesc << std::endl;
|
||||
std::cout << "Scale: " << scale << std::endl;
|
||||
@@ -164,8 +164,8 @@ int main(int argc, char* argv[])
|
||||
auto launch_transpose_scale = [&]() {
|
||||
auto transposeScale = DeviceElementwisePermuteInstance{};
|
||||
auto argument = transposeScale.MakeArgumentPointer(dims,
|
||||
{in_strides},
|
||||
{out_strides, in_strides},
|
||||
{in_strides},
|
||||
{out_strides, in_strides},
|
||||
inputs,
|
||||
outputs,
|
||||
ScalePassThrough{scale});
|
||||
|
||||
@@ -213,7 +213,7 @@ int main(int argc, char* argv[])
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
auto argument = device_op.MakeArgument(
|
||||
std::array<const void*, 2>{a0_device_buf.GetDeviceBuffer(),
|
||||
a1_device_buf.GetDeviceBuffer()},
|
||||
a1_device_buf.GetDeviceBuffer()},
|
||||
std::array<const void*, 1>{b_device_buf.GetDeviceBuffer()},
|
||||
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
|
||||
e_device_buf.GetDeviceBuffer(),
|
||||
|
||||
@@ -194,9 +194,9 @@ int main(int argc, char* argv[])
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
auto argument = device_op.MakeArgument(
|
||||
std::array<const void*, 2>{a0_device_buf.GetDeviceBuffer(),
|
||||
a1_device_buf.GetDeviceBuffer()},
|
||||
a1_device_buf.GetDeviceBuffer()},
|
||||
std::array<const void*, 2>{b0_device_buf.GetDeviceBuffer(),
|
||||
b1_device_buf.GetDeviceBuffer()},
|
||||
b1_device_buf.GetDeviceBuffer()},
|
||||
std::array<const void*, 0>{},
|
||||
e_device_buf.GetDeviceBuffer(),
|
||||
std::array<std::vector<ck::index_t>, 2>{a0_ms_ks_lengths, a1_ms_ks_lengths},
|
||||
|
||||
@@ -265,10 +265,10 @@ bool run_grouped_conv_fwd(bool do_verification,
|
||||
auto device_ew_scale = DeviceElementwiseScale{};
|
||||
auto scale_invoker = device_ew_scale.MakeInvoker();
|
||||
auto scale_argument = device_ew_scale.MakeArgument(e_g_n_k_wos_lengths,
|
||||
{e_g_n_k_wos_strides},
|
||||
{e_g_n_k_wos_strides},
|
||||
{conv_device_buf.GetDeviceBuffer()},
|
||||
{out_device_buf.GetDeviceBuffer()},
|
||||
{e_g_n_k_wos_strides},
|
||||
{e_g_n_k_wos_strides},
|
||||
{conv_device_buf.GetDeviceBuffer()},
|
||||
{out_device_buf.GetDeviceBuffer()},
|
||||
scale_convert);
|
||||
|
||||
if(!device_ew_scale.IsSupportedArgument(scale_argument))
|
||||
|
||||
@@ -46,9 +46,9 @@ int run_layernorm4d_fwd_example()
|
||||
{0, W * C, C, 1},
|
||||
std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
|
||||
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
|
||||
save_mean.mDesc.GetStrides().end()},
|
||||
save_mean.mDesc.GetStrides().end()},
|
||||
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
|
||||
save_mean.mDesc.GetStrides().end()},
|
||||
save_mean.mDesc.GetStrides().end()},
|
||||
{1, 2, 3},
|
||||
1e-4,
|
||||
x_dev.GetDeviceBuffer(),
|
||||
|
||||
@@ -357,7 +357,7 @@ int main(int argc, char* argv[])
|
||||
int n1 = n % NLane;
|
||||
|
||||
int k0 = k / (KLane * KPack);
|
||||
tempk = k % (KLane * KPack);
|
||||
tempk = k % (KLane * KPack);
|
||||
int k1 = tempk / KPack;
|
||||
int k2 = tempk % KPack;
|
||||
|
||||
|
||||
@@ -191,8 +191,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
return base_str;
|
||||
}();
|
||||
|
||||
std::cout << "[" << prec_str << "]"
|
||||
<< " m:" << m << ", n:" << n << ", x_stride:" << x_stride
|
||||
std::cout << "[" << prec_str << "]" << " m:" << m << ", n:" << n << ", x_stride:" << x_stride
|
||||
<< ", xr_stride:" << xr_stride << ", y_stride:" << y_stride
|
||||
<< ", yr_stride:" << yr_stride << std::flush;
|
||||
|
||||
|
||||
@@ -333,12 +333,12 @@ struct matrix_core_swizzle_kernel
|
||||
return tmp_1;
|
||||
#else
|
||||
// b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv,
|
||||
constexpr index_t kv = Alignment;
|
||||
constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
|
||||
constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
|
||||
constexpr index_t kv = Alignment;
|
||||
constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
|
||||
constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
|
||||
constexpr index_t waveflatten = kw * nw * kv;
|
||||
const index_t kr = a_.k / (k1 * k2);
|
||||
const index_t nr = a_.n / nw;
|
||||
const index_t kr = a_.k / (k1 * k2);
|
||||
const index_t nr = a_.n / nw;
|
||||
auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
p_dst,
|
||||
make_tuple(nr, kr, waveflatten),
|
||||
@@ -387,8 +387,8 @@ struct matrix_core_swizzle_kernel
|
||||
constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
|
||||
constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
|
||||
constexpr index_t waveflatten_tile = kw * nw * kv;
|
||||
constexpr index_t nr_tile = NPerBlock / nw;
|
||||
constexpr index_t kr_tile = KPerBlock / (kw * kv);
|
||||
constexpr index_t nr_tile = NPerBlock / nw;
|
||||
constexpr index_t kr_tile = KPerBlock / (kw * kv);
|
||||
return make_tile_window(dst_view,
|
||||
make_tuple(number<nr_tile>{},
|
||||
number<kr_tile>{},
|
||||
|
||||
@@ -183,8 +183,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "[" << data_type << "]"
|
||||
<< " m:" << m << ", n:" << n << ", stride:" << stride
|
||||
std::cout << "[" << data_type << "]" << " m:" << m << ", n:" << n << ", stride:" << stride
|
||||
<< ", s:" << USEModelSensitive << ", valid:" << (pass ? "y" : "n") << std::flush
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
@@ -193,8 +193,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
return base_str;
|
||||
}();
|
||||
|
||||
std::cout << "[" << prec_str << "]"
|
||||
<< " m:" << m << ", n:" << n << ", x_stride:" << x_stride
|
||||
std::cout << "[" << prec_str << "]" << " m:" << m << ", n:" << n << ", x_stride:" << x_stride
|
||||
<< ", xr_stride:" << xr_stride << ", y_stride:" << y_stride
|
||||
<< ", yr_stride:" << yr_stride << ", s:" << use_model_sensitive_rmsnorm << std::flush;
|
||||
|
||||
|
||||
@@ -105,8 +105,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
b_buf.ToDevice(b_host.data());
|
||||
gamma_buf.ToDevice(gamma_host.data());
|
||||
|
||||
std::cout << "[" << input_data_type << ", " << quantized_data_type << "]"
|
||||
<< " m:" << m << ", n:" << n << ", stride:" << stride << std::flush;
|
||||
std::cout << "[" << input_data_type << ", " << quantized_data_type << "]" << " m:" << m
|
||||
<< ", n:" << n << ", stride:" << stride << std::flush;
|
||||
|
||||
add_rmsnorm2d_rdquant_fwd_traits traits{input_data_type, quantized_data_type, SaveX};
|
||||
|
||||
|
||||
@@ -256,8 +256,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "[" << data_type << "]"
|
||||
<< " m:" << m << ", n:" << n << ", stride:" << stride
|
||||
std::cout << "[" << data_type << "]" << " m:" << m << ", n:" << n << ", stride:" << stride
|
||||
<< ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
}
|
||||
|
||||
|
||||
@@ -216,10 +216,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "[" << data_type << "]"
|
||||
<< " m:" << m << ", n:" << n << ", x_stride:" << x_stride
|
||||
<< ", y_stride:" << y_stride << ", valid:" << (pass ? "y" : "n") << std::flush
|
||||
<< std::endl;
|
||||
std::cout << "[" << data_type << "]" << " m:" << m << ", n:" << n
|
||||
<< ", x_stride:" << x_stride << ", y_stride:" << y_stride
|
||||
<< ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
|
||||
@@ -93,9 +93,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
x_buf.ToDevice(x_host.data());
|
||||
smscale_buf.ToDevice(smscale_host.data());
|
||||
|
||||
std::cout << "[" << data_type << "]"
|
||||
<< " m:" << m << ", n:" << n << ", x_stride:" << x_stride << ", y_stride:" << y_stride
|
||||
<< std::flush;
|
||||
std::cout << "[" << data_type << "]" << " m:" << m << ", n:" << n << ", x_stride:" << x_stride
|
||||
<< ", y_stride:" << y_stride << std::flush;
|
||||
|
||||
smoothquant_traits traits{data_type};
|
||||
|
||||
|
||||
@@ -228,20 +228,26 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
moe_sorting_trait trait{
|
||||
index_prec, weight_prec, local_expert_masking, clear_inside, dispatch_policy};
|
||||
|
||||
moe_sorting_args karg
|
||||
{
|
||||
topk_ids_dev.GetDeviceBuffer(), weights_dev.GetDeviceBuffer(),
|
||||
local_expert_masking ? local_expert_masking_dev.GetDeviceBuffer() : nullptr,
|
||||
is_local_token ? local_tokens_dev.GetDeviceBuffer() : nullptr,
|
||||
sorted_ids_dev.GetDeviceBuffer(), sorted_weights_dev.GetDeviceBuffer(),
|
||||
sorted_expert_ids_dev.GetDeviceBuffer(), sorted_id_cnt_dev.GetDeviceBuffer(),
|
||||
moe_buf_bytes > 0 ? moe_buf_dev.GetDeviceBuffer() : nullptr,
|
||||
workspace_size != 0 ? moe_sorting_ws.GetDeviceBuffer() : nullptr, tokens, unit_size,
|
||||
num_experts, topk,
|
||||
moe_sorting_args karg{topk_ids_dev.GetDeviceBuffer(),
|
||||
weights_dev.GetDeviceBuffer(),
|
||||
local_expert_masking ? local_expert_masking_dev.GetDeviceBuffer()
|
||||
: nullptr,
|
||||
is_local_token ? local_tokens_dev.GetDeviceBuffer() : nullptr,
|
||||
sorted_ids_dev.GetDeviceBuffer(),
|
||||
sorted_weights_dev.GetDeviceBuffer(),
|
||||
sorted_expert_ids_dev.GetDeviceBuffer(),
|
||||
sorted_id_cnt_dev.GetDeviceBuffer(),
|
||||
moe_buf_bytes > 0 ? moe_buf_dev.GetDeviceBuffer() : nullptr,
|
||||
workspace_size != 0 ? moe_sorting_ws.GetDeviceBuffer() : nullptr,
|
||||
tokens,
|
||||
unit_size,
|
||||
num_experts,
|
||||
topk,
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
moe_buf_interm_dim, moe_buf_elem_bytes
|
||||
moe_buf_interm_dim,
|
||||
moe_buf_elem_bytes
|
||||
#else
|
||||
static_cast<ck_tile::long_index_t>(moe_buf_size * sizeof(float))
|
||||
static_cast<ck_tile::long_index_t>(moe_buf_size * sizeof(float))
|
||||
#endif
|
||||
};
|
||||
|
||||
|
||||
@@ -40,11 +40,11 @@
|
||||
constexpr bool local_expert_masking = local_expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemEx<index_t, \
|
||||
ms_weight_type, \
|
||||
sub_token_tile, \
|
||||
sub_token_onshot, \
|
||||
local_expert_masking, \
|
||||
local_token>; \
|
||||
ms_weight_type, \
|
||||
sub_token_tile, \
|
||||
sub_token_onshot, \
|
||||
local_expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -200,11 +200,11 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -218,11 +218,11 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P1<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -236,11 +236,11 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P2<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -254,11 +254,11 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P3<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -273,11 +273,11 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P23<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
|
||||
@@ -124,9 +124,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
smscale_buf.ToDevice(smscale_host.data());
|
||||
topk_ids_buf.ToDevice(topk_ids_host.data());
|
||||
|
||||
std::cout << "[" << prec_i << "-" << prec_o << "]"
|
||||
<< " tokens:" << tokens << ", hidden_size:" << hidden_size << ", stride:" << stride
|
||||
<< ", experts:" << experts << ", topk:" << topk << std::flush;
|
||||
std::cout << "[" << prec_i << "-" << prec_o << "]" << " tokens:" << tokens
|
||||
<< ", hidden_size:" << hidden_size << ", stride:" << stride << ", experts:" << experts
|
||||
<< ", topk:" << topk << std::flush;
|
||||
|
||||
moe_smoothquant_traits traits{prec_i, prec_o};
|
||||
|
||||
|
||||
@@ -25,27 +25,27 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf
|
||||
}();
|
||||
|
||||
auto t0 = fused_moesorting_trait{"int32", "fp32", t.local_expert_masking};
|
||||
auto a0 = fused_moesorting_args
|
||||
{
|
||||
a.topk_ids_ptr, // const void* p_topk_ids;
|
||||
a.topk_weight_ptr, // const void* p_weights;
|
||||
a.local_expert_mask_ptr, // const void* p_local_expert_mask;
|
||||
a.local_tokens,
|
||||
a.sorted_token_ids_ptr, // void* p_sorted_token_ids;
|
||||
a.sorted_weight_ptr, // void* p_sorted_weights;
|
||||
a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids;
|
||||
a.num_sorted_tiles_ptr, // void* p_total_tokens_post_pad;
|
||||
a.o_ptr, // void* p_moe_buf;
|
||||
a.ws_ptr, // void* p_ws;
|
||||
a.num_tokens, // index_t tokens;
|
||||
a.block_m, // index_t unit_size;
|
||||
a.num_experts, // index_t num_experts;
|
||||
a.topk, // index_t topk;
|
||||
auto a0 = fused_moesorting_args{
|
||||
a.topk_ids_ptr, // const void* p_topk_ids;
|
||||
a.topk_weight_ptr, // const void* p_weights;
|
||||
a.local_expert_mask_ptr, // const void* p_local_expert_mask;
|
||||
a.local_tokens,
|
||||
a.sorted_token_ids_ptr, // void* p_sorted_token_ids;
|
||||
a.sorted_weight_ptr, // void* p_sorted_weights;
|
||||
a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids;
|
||||
a.num_sorted_tiles_ptr, // void* p_total_tokens_post_pad;
|
||||
a.o_ptr, // void* p_moe_buf;
|
||||
a.ws_ptr, // void* p_ws;
|
||||
a.num_tokens, // index_t tokens;
|
||||
a.block_m, // index_t unit_size;
|
||||
a.num_experts, // index_t num_experts;
|
||||
a.topk, // index_t topk;
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
a.stride_token, o_data_bytes,
|
||||
a.stride_token,
|
||||
o_data_bytes,
|
||||
#else
|
||||
static_cast<ck_tile::long_index_t>(a.num_tokens) *
|
||||
a.stride_token* o_data_bytes // index_t moe_buf_bytes;
|
||||
static_cast<ck_tile::long_index_t>(a.num_tokens) * a.stride_token *
|
||||
o_data_bytes // index_t moe_buf_bytes;
|
||||
#endif
|
||||
};
|
||||
|
||||
|
||||
@@ -16,11 +16,11 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
|
||||
{
|
||||
using f_traits = ck_tile::FusedMoeGemmTraits<Ts_::GateOnly, Ts_::FusedQuant == 1, 1 /*atomic*/>;
|
||||
using f_shape = ck_tile::FusedMoeGemmShape<typename Ts_::BlockTile_0,
|
||||
typename Ts_::WarpPerBlock_0,
|
||||
typename Ts_::WarpTile_0,
|
||||
typename Ts_::BlockTile_1,
|
||||
typename Ts_::WarpPerBlock_0,
|
||||
typename Ts_::WarpTile_0>;
|
||||
typename Ts_::WarpPerBlock_0,
|
||||
typename Ts_::WarpTile_0,
|
||||
typename Ts_::BlockTile_1,
|
||||
typename Ts_::WarpPerBlock_0,
|
||||
typename Ts_::WarpTile_0>;
|
||||
|
||||
constexpr auto get_activation_ = []() {
|
||||
if constexpr(Ts_::Activation == 0)
|
||||
|
||||
@@ -40,11 +40,11 @@
|
||||
constexpr bool local_expert_masking = local_expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemEx<index_t, \
|
||||
ms_weight_type, \
|
||||
sub_token_tile, \
|
||||
sub_token_onshot, \
|
||||
local_expert_masking, \
|
||||
local_token>; \
|
||||
ms_weight_type, \
|
||||
sub_token_tile, \
|
||||
sub_token_onshot, \
|
||||
local_expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -204,11 +204,11 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -222,11 +222,11 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P1<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -240,11 +240,11 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P2<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -258,11 +258,11 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P3<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -277,11 +277,11 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P23<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
|
||||
@@ -218,8 +218,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
return std::string(", st:") + std::to_string(stride);
|
||||
}();
|
||||
|
||||
std::cout << "[" << api_str << "|" << prec_str << "]"
|
||||
<< " t:" << tokens;
|
||||
std::cout << "[" << api_str << "|" << prec_str << "]" << " t:" << tokens;
|
||||
|
||||
if(is_local_token)
|
||||
{
|
||||
|
||||
@@ -173,10 +173,9 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:"
|
||||
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
ave_time =
|
||||
|
||||
@@ -138,10 +138,9 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:"
|
||||
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
ave_time =
|
||||
|
||||
@@ -216,9 +216,9 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
c_m_n_tensors.push_back(ck_tile::HostTensor<CDataType>(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_Cs[i], is_row_major(CLayout{}))));
|
||||
|
||||
std::cout << "gemm[" << i << "]"
|
||||
<< " a_m_k: " << a_m_k_tensors[i].mDesc << " b_k_n: " << b_k_n_tensors[i].mDesc
|
||||
<< " c_m_n: " << c_m_n_tensors[i].mDesc << std::endl;
|
||||
std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc
|
||||
<< " b_k_n: " << b_k_n_tensors[i].mDesc << " c_m_n: " << c_m_n_tensors[i].mDesc
|
||||
<< std::endl;
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n_tensors[i]);
|
||||
|
||||
@@ -170,10 +170,9 @@ auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config&
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:"
|
||||
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << std::endl;
|
||||
std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", "
|
||||
<< grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", "
|
||||
<< blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
|
||||
@@ -161,8 +161,7 @@ int run_grouped_conv_bwd_weight_example_with_layouts(
|
||||
conv_param.conv_filter_dilations_,
|
||||
conv_param.input_left_pads_,
|
||||
conv_param.input_right_pads_);
|
||||
const ck_tile::index_t GemmK =
|
||||
weight.get_element_size() / (conv_param.G_ * conv_param.K_);
|
||||
const ck_tile::index_t GemmK = weight.get_element_size() / (conv_param.G_ * conv_param.K_);
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(weight_host_ref.mData.begin(), weight_host_ref.mData.end());
|
||||
const auto rtol_atol =
|
||||
|
||||
@@ -87,24 +87,24 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
|
||||
tail_number_v>;
|
||||
using CodegenGemmPipeline = ck_tile::AQuantGemmPipelineAgBgCrCompV3<CodegenPipelineProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
CodegenPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
transposed_warp_gemm,
|
||||
ck_tile::memory_operation_enum::set>>;
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
CodegenPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
transposed_warp_gemm,
|
||||
ck_tile::memory_operation_enum::set>>;
|
||||
using Kernel =
|
||||
ck_tile::AQuantGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
|
||||
|
||||
@@ -195,14 +195,18 @@ int run_gemm_example(int argc, char* argv[])
|
||||
}
|
||||
else if(data_type == "i4fp8")
|
||||
{
|
||||
using TypeConfig = decltype(
|
||||
GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, ck_tile::fp8_t>{});
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
||||
ck_tile::fp8_t,
|
||||
float,
|
||||
ck_tile::fp8_t>{});
|
||||
return run_gemm_example_prec_type<TypeConfig, 128>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "i4bf8")
|
||||
{
|
||||
using TypeConfig = decltype(
|
||||
GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, ck_tile::bf8_t>{});
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
||||
ck_tile::bf8_t,
|
||||
float,
|
||||
ck_tile::bf8_t>{});
|
||||
return run_gemm_example_prec_type<TypeConfig, 128>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "i4f32fp8")
|
||||
|
||||
@@ -13,7 +13,7 @@ for p in sorted(Path("./").rglob("*")):
|
||||
# formatting
|
||||
for x in all_files:
|
||||
subprocess.Popen(f'dos2unix {str(x)}', shell=True)
|
||||
cmd = f'clang-format-12 -style=file -i {str(x)}'
|
||||
cmd = f'clang-format-18 -style=file -i {str(x)}'
|
||||
#for xp in x.parents:
|
||||
#print(get_file_base(x))
|
||||
subprocess.Popen(cmd, shell=True)
|
||||
|
||||
@@ -12,9 +12,8 @@ inline void hip_check_error(hipError_t x)
|
||||
if(x != hipSuccess)
|
||||
{
|
||||
std::ostringstream ss;
|
||||
ss << "HIP runtime error: " << hipGetErrorString(x) << ". "
|
||||
<< "hip_check_error.hpp"
|
||||
<< ": " << __LINE__ << "in function: " << __func__;
|
||||
ss << "HIP runtime error: " << hipGetErrorString(x) << ". " << "hip_check_error.hpp" << ": "
|
||||
<< __LINE__ << "in function: " << __func__;
|
||||
throw std::runtime_error(ss.str());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,10 +11,10 @@
|
||||
namespace ck {
|
||||
namespace ranges {
|
||||
template <typename InputRange, typename OutputIterator>
|
||||
auto copy(InputRange&& range, OutputIterator iter)
|
||||
-> decltype(std::copy(std::begin(std::forward<InputRange>(range)),
|
||||
std::end(std::forward<InputRange>(range)),
|
||||
iter))
|
||||
auto copy(InputRange&& range,
|
||||
OutputIterator iter) -> decltype(std::copy(std::begin(std::forward<InputRange>(range)),
|
||||
std::end(std::forward<InputRange>(range)),
|
||||
iter))
|
||||
{
|
||||
return std::copy(std::begin(std::forward<InputRange>(range)),
|
||||
std::end(std::forward<InputRange>(range)),
|
||||
|
||||
@@ -138,9 +138,10 @@ struct FillConstant
|
||||
}
|
||||
|
||||
template <typename ForwardRange>
|
||||
auto operator()(ForwardRange&& range) const -> std::void_t<
|
||||
decltype(std::declval<const FillConstant&>()(std::begin(std::forward<ForwardRange>(range)),
|
||||
std::end(std::forward<ForwardRange>(range))))>
|
||||
auto operator()(ForwardRange&& range) const
|
||||
-> std::void_t<decltype(std::declval<const FillConstant&>()(
|
||||
std::begin(std::forward<ForwardRange>(range)),
|
||||
std::end(std::forward<ForwardRange>(range))))>
|
||||
{
|
||||
(*this)(std::begin(std::forward<ForwardRange>(range)),
|
||||
std::end(std::forward<ForwardRange>(range)));
|
||||
|
||||
@@ -202,7 +202,7 @@ struct joinable_thread : std::thread
|
||||
{
|
||||
}
|
||||
|
||||
joinable_thread(joinable_thread&&) = default;
|
||||
joinable_thread(joinable_thread&&) = default;
|
||||
joinable_thread& operator=(joinable_thread&&) = default;
|
||||
|
||||
~joinable_thread()
|
||||
@@ -320,7 +320,7 @@ struct Tensor
|
||||
~Tensor() = default;
|
||||
|
||||
Tensor& operator=(const Tensor&) = default;
|
||||
Tensor& operator=(Tensor&&) = default;
|
||||
Tensor& operator=(Tensor&&) = default;
|
||||
|
||||
template <typename FromT>
|
||||
explicit Tensor(const Tensor<FromT>& other) : Tensor(other.template CopyAsType<T>())
|
||||
|
||||
@@ -108,13 +108,13 @@ struct TensorAdaptor
|
||||
|
||||
__host__ __device__ static constexpr index_t GetNumOfHiddenDimension()
|
||||
{
|
||||
constexpr auto all_low_dim_ids = unpack(
|
||||
[](auto&&... xs) constexpr { return merge_sequences(xs...); },
|
||||
LowerDimensionHiddenIdss{});
|
||||
constexpr auto all_low_dim_ids =
|
||||
unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); },
|
||||
LowerDimensionHiddenIdss{});
|
||||
|
||||
constexpr auto all_up_dim_ids = unpack(
|
||||
[](auto&&... xs) constexpr { return merge_sequences(xs...); },
|
||||
UpperDimensionHiddenIdss{});
|
||||
constexpr auto all_up_dim_ids =
|
||||
unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); },
|
||||
UpperDimensionHiddenIdss{});
|
||||
|
||||
constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids);
|
||||
|
||||
@@ -338,8 +338,7 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
|
||||
TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran];
|
||||
|
||||
// sequence in, sequence out
|
||||
constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr
|
||||
{
|
||||
constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr {
|
||||
auto low_dim_hidden_ids_1_mod_ = to_multi_index(low_dim_hidden_ids_1);
|
||||
|
||||
// shift hidden id so every dim id is unique
|
||||
@@ -361,8 +360,7 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
|
||||
});
|
||||
|
||||
return low_dim_hidden_ids_1_mod_;
|
||||
}
|
||||
();
|
||||
}();
|
||||
|
||||
return generate_sequence_v2(
|
||||
[&](auto i) constexpr { return Number<low_dim_hidden_ids_1_mod[i]>{}; },
|
||||
@@ -384,8 +382,7 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
|
||||
TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran];
|
||||
|
||||
// sequence in, constexpr tuple out
|
||||
constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr
|
||||
{
|
||||
constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr {
|
||||
auto up_dim_hidden_ids_1_mod_ = to_multi_index(up_dim_hidden_ids_1);
|
||||
|
||||
// shift hidden id
|
||||
@@ -394,8 +391,7 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
|
||||
});
|
||||
|
||||
return up_dim_hidden_ids_1_mod_;
|
||||
}
|
||||
();
|
||||
}();
|
||||
|
||||
// constexpr tuple to sequence
|
||||
return generate_sequence_v2(
|
||||
|
||||
@@ -365,7 +365,7 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
|
||||
Sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, math::plus<index_t>{}, Number<0>{}));
|
||||
|
||||
constexpr auto up_dim_hidden_idss = generate_tuple(
|
||||
[ old_hidden_dim_number, up_dim_numbers_scan ](auto i) constexpr {
|
||||
[old_hidden_dim_number, up_dim_numbers_scan](auto i) constexpr {
|
||||
return
|
||||
typename arithmetic_sequence_gen<old_hidden_dim_number + up_dim_numbers_scan[i],
|
||||
old_hidden_dim_number + up_dim_numbers_scan[i + 1],
|
||||
@@ -374,12 +374,12 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
|
||||
Number<num_new_transform>{});
|
||||
|
||||
// new visible dimension's hidden ids
|
||||
constexpr auto unordered_new_visible_dim_hidden_ids = unpack(
|
||||
[](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss);
|
||||
constexpr auto unordered_new_visible_dim_hidden_ids =
|
||||
unpack([](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss);
|
||||
|
||||
constexpr auto new_visible_dim_unordered2ordered = unpack(
|
||||
[](auto... xs) constexpr { return merge_sequences(xs...); },
|
||||
NewUpperDimensionNewVisibleIdss{});
|
||||
constexpr auto new_visible_dim_unordered2ordered =
|
||||
unpack([](auto... xs) constexpr { return merge_sequences(xs...); },
|
||||
NewUpperDimensionNewVisibleIdss{});
|
||||
|
||||
constexpr auto new_visible_dim_hidden_ids =
|
||||
unordered_new_visible_dim_hidden_ids.ReorderGivenOld2New(new_visible_dim_unordered2ordered);
|
||||
|
||||
@@ -94,10 +94,8 @@ struct SpaceFillingCurve
|
||||
// Given tensor strides \p access_lengths, and 1D index of space-filling-curve, compute the
|
||||
// idim-th element of multidimensional index.
|
||||
// All constexpr variables have to be captured by VALUE.
|
||||
constexpr auto compute_index = [ idx_1d, access_strides ](auto idim) constexpr
|
||||
{
|
||||
constexpr auto compute_index_impl = [ idx_1d, access_strides ](auto jdim) constexpr
|
||||
{
|
||||
constexpr auto compute_index = [idx_1d, access_strides](auto idim) constexpr {
|
||||
constexpr auto compute_index_impl = [idx_1d, access_strides](auto jdim) constexpr {
|
||||
auto res = idx_1d.value;
|
||||
auto id = 0;
|
||||
|
||||
|
||||
@@ -152,7 +152,7 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
|
||||
|
||||
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
|
||||
__device__ static auto
|
||||
CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
|
||||
CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
|
||||
{
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
|
||||
|
||||
@@ -93,7 +93,7 @@ struct BlockwiseGemmWmmaops_pipeline_base
|
||||
|
||||
struct Empty
|
||||
{
|
||||
__device__ Empty(){};
|
||||
__device__ Empty() {};
|
||||
template <index_t NBuffer>
|
||||
__device__ void GlobalLoad(bool cond)
|
||||
{
|
||||
@@ -119,7 +119,7 @@ struct BlockwiseGemmWmmaops_pipeline_base
|
||||
GridBuffer b_scale_grid_buf_)
|
||||
: b_scale_thread_copy(b_scale_thread_copy_),
|
||||
b_scale_grid_desc(b_scale_grid_desc_),
|
||||
b_scale_grid_buf(b_scale_grid_buf_){};
|
||||
b_scale_grid_buf(b_scale_grid_buf_) {};
|
||||
|
||||
static constexpr index_t num_scale_k_block = BScaleThreadDesc{}.GetLength(Number<1>{});
|
||||
static constexpr index_t num_scale_krepeat = KRepeat / num_scale_k_block;
|
||||
|
||||
@@ -96,9 +96,9 @@ template <
|
||||
index_t KPack,
|
||||
bool TransposeC = false,
|
||||
index_t AMmaKStride =
|
||||
KPack* XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{}.K0PerXdlops,
|
||||
KPack * XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{}.K0PerXdlops,
|
||||
index_t BMmaKStride =
|
||||
KPack* XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{}.K0PerXdlops>
|
||||
KPack * XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{}.K0PerXdlops>
|
||||
struct BlockwiseGemmXdlops_pipeline_v4
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
@@ -188,7 +188,7 @@ struct BlockwiseGemmXdlops_pipeline_v4
|
||||
|
||||
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
|
||||
__device__ static auto
|
||||
CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
|
||||
CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
|
||||
{
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
|
||||
@@ -217,7 +217,7 @@ struct BlockwiseGemmXdlops_pipeline_v4
|
||||
|
||||
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
|
||||
__device__ static auto
|
||||
CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
|
||||
CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
|
||||
{
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
|
||||
|
||||
@@ -153,7 +153,7 @@ struct BlockwiseGemmXdlops_pipeline_base
|
||||
|
||||
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
|
||||
__device__ static auto
|
||||
CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
|
||||
CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
|
||||
{
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
|
||||
@@ -182,7 +182,7 @@ struct BlockwiseGemmXdlops_pipeline_base
|
||||
|
||||
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
|
||||
__device__ static auto
|
||||
CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
|
||||
CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
|
||||
{
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
|
||||
|
||||
@@ -110,7 +110,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
|
||||
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
|
||||
__device__ static auto
|
||||
CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
|
||||
CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
|
||||
{
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
const auto waveId_m = wave_idx[I0];
|
||||
@@ -138,7 +138,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
|
||||
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
|
||||
__device__ static auto
|
||||
CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
|
||||
CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
|
||||
{
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
const auto waveId_m = wave_idx[I0];
|
||||
|
||||
@@ -114,7 +114,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
|
||||
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
|
||||
__device__ static auto
|
||||
CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
|
||||
CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
|
||||
{
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
|
||||
@@ -143,7 +143,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
|
||||
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
|
||||
__device__ static auto
|
||||
CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
|
||||
CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
|
||||
{
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
|
||||
@@ -667,9 +667,9 @@ template <
|
||||
index_t KPack,
|
||||
bool TransposeC = false,
|
||||
index_t AMmaKStride =
|
||||
KPack* XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{}.K0PerXdlops,
|
||||
KPack * XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{}.K0PerXdlops,
|
||||
index_t BMmaKStride =
|
||||
KPack* XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{}.K0PerXdlops>
|
||||
KPack * XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{}.K0PerXdlops>
|
||||
struct BlockwiseGemmXdlops_v2
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
@@ -742,7 +742,7 @@ struct BlockwiseGemmXdlops_v2
|
||||
|
||||
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
|
||||
__device__ static auto
|
||||
CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
|
||||
CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
|
||||
{
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
|
||||
@@ -771,7 +771,7 @@ struct BlockwiseGemmXdlops_v2
|
||||
|
||||
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
|
||||
__device__ static auto
|
||||
CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
|
||||
CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
|
||||
{
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
|
||||
|
||||
@@ -90,7 +90,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
|
||||
|
||||
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
|
||||
__device__ static auto
|
||||
CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
|
||||
CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
|
||||
{
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
|
||||
|
||||
@@ -258,8 +258,7 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
|
||||
src_buf.template DirectCopyToLds<remove_cvref_t<decltype(dst_buf)>, ScalarPerVector>(
|
||||
dst_buf, src_offset, dst_offset, is_src_valid);
|
||||
|
||||
constexpr auto move_on_dim = [&]() constexpr
|
||||
{
|
||||
constexpr auto move_on_dim = [&]() constexpr {
|
||||
StaticallyIndexedArray<bool, nDim> move_on_dim_;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
@@ -271,8 +270,7 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
|
||||
});
|
||||
|
||||
return move_on_dim_;
|
||||
}
|
||||
();
|
||||
}();
|
||||
|
||||
// Decide whether to move forward or backward.
|
||||
constexpr auto forward_sweep = [&]() {
|
||||
|
||||
@@ -281,8 +281,7 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad
|
||||
src_buf.template DirectCopyToLds<remove_cvref_t<decltype(dst_buf)>, ScalarPerVector>(
|
||||
dst_buf, src_offset, dst_offset, true);
|
||||
|
||||
constexpr auto move_src_on_dim = [&]() constexpr
|
||||
{
|
||||
constexpr auto move_src_on_dim = [&]() constexpr {
|
||||
StaticallyIndexedArray<bool, nDim> move_on_dim_;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
@@ -295,11 +294,9 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad
|
||||
});
|
||||
|
||||
return move_on_dim_;
|
||||
}
|
||||
();
|
||||
}();
|
||||
|
||||
constexpr auto move_dst_on_dim = [&]() constexpr
|
||||
{
|
||||
constexpr auto move_dst_on_dim = [&]() constexpr {
|
||||
StaticallyIndexedArray<bool, nDim> move_on_dim_;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
@@ -311,8 +308,7 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad
|
||||
});
|
||||
|
||||
return move_on_dim_;
|
||||
}
|
||||
();
|
||||
}();
|
||||
|
||||
// Decide whether to move forward or backward.
|
||||
constexpr auto forward_sweep = [&]() {
|
||||
|
||||
@@ -49,8 +49,8 @@ namespace device {
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
struct BaseArgument
|
||||
{
|
||||
BaseArgument() = default;
|
||||
BaseArgument(const BaseArgument&) = default;
|
||||
BaseArgument() = default;
|
||||
BaseArgument(const BaseArgument&) = default;
|
||||
BaseArgument& operator=(const BaseArgument&) = default;
|
||||
|
||||
virtual ~BaseArgument() {}
|
||||
@@ -60,8 +60,8 @@ struct BaseArgument
|
||||
|
||||
struct BaseInvoker
|
||||
{
|
||||
BaseInvoker() = default;
|
||||
BaseInvoker(const BaseInvoker&) = default;
|
||||
BaseInvoker() = default;
|
||||
BaseInvoker(const BaseInvoker&) = default;
|
||||
BaseInvoker& operator=(const BaseInvoker&) = default;
|
||||
|
||||
virtual float Run(const BaseArgument*, const StreamConfig& = StreamConfig{})
|
||||
@@ -75,8 +75,8 @@ struct BaseInvoker
|
||||
|
||||
struct BaseOperator
|
||||
{
|
||||
BaseOperator() = default;
|
||||
BaseOperator(const BaseOperator&) = default;
|
||||
BaseOperator() = default;
|
||||
BaseOperator(const BaseOperator&) = default;
|
||||
BaseOperator& operator=(const BaseOperator&) = default;
|
||||
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
|
||||
virtual bool IsSupportedArgument(const BaseArgument*) { return false; }
|
||||
|
||||
@@ -70,15 +70,9 @@ struct GroupedGemmKernelArgument
|
||||
for(auto sd : StrideDs)
|
||||
str << sd << ",";
|
||||
|
||||
std::cout << "arg {"
|
||||
<< "M:" << M << ", "
|
||||
<< "N:" << N << ", "
|
||||
<< "K:" << K << ", "
|
||||
<< "SA:" << StrideA << ", "
|
||||
<< "SB:" << StrideB << ", "
|
||||
<< "SE:" << StrideE << ", "
|
||||
<< "SDs: {" << str.str() << "}"
|
||||
<< "}" << std::endl;
|
||||
std::cout << "arg {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
|
||||
<< "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SE:" << StrideE
|
||||
<< ", " << "SDs: {" << str.str() << "}" << "}" << std::endl;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -205,25 +205,25 @@ template <typename GridwiseGemm,
|
||||
bool isMultiB>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle(
|
||||
AsPointer p_as_grid,
|
||||
BsPointer p_bs_grid,
|
||||
DsPointer p_ds_grid,
|
||||
EDataType* __restrict__ p_e_grid,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CDEElementwiseOperation cde_element_op,
|
||||
const index_t batch_count,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
|
||||
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
const Block2ETileMap block_2_ctile_map,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
|
||||
kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle(
|
||||
AsPointer p_as_grid,
|
||||
BsPointer p_bs_grid,
|
||||
DsPointer p_ds_grid,
|
||||
EDataType* __restrict__ p_e_grid,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CDEElementwiseOperation cde_element_op,
|
||||
const index_t batch_count,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
|
||||
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
const Block2ETileMap block_2_ctile_map,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
|
||||
{
|
||||
|
||||
device_grouped_conv_fwd_multiple_abd_xdl_cshuffle<
|
||||
|
||||
@@ -36,25 +36,25 @@ template <typename GridwiseGemm,
|
||||
bool HasMainKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_contraction_multiple_d_xdl_cshuffle(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatDsPointer p_ds_grid,
|
||||
FloatE* __restrict__ p_e_grid,
|
||||
const index_t batch_count,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CDEElementwiseOperation cde_element_op,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const Block2ETileMap block_2_etile_map)
|
||||
kernel_contraction_multiple_d_xdl_cshuffle(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatDsPointer p_ds_grid,
|
||||
FloatE* __restrict__ p_e_grid,
|
||||
const index_t batch_count,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CDEElementwiseOperation cde_element_op,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const Block2ETileMap block_2_etile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
@@ -58,21 +58,21 @@ template <typename GridwiseGemm,
|
||||
bool HasMainKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_batched_gemm_e_permute_xdl(const ABDataType* __restrict__ p_a_grid,
|
||||
const ABDataType* __restrict__ p_b_grid,
|
||||
EDataType* __restrict__ p_e_grid,
|
||||
const index_t batch_count,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CDEElementwiseOperation cde_element_op,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const Block2ETileMap block_2_etile_map)
|
||||
kernel_batched_gemm_e_permute_xdl(const ABDataType* __restrict__ p_a_grid,
|
||||
const ABDataType* __restrict__ p_b_grid,
|
||||
EDataType* __restrict__ p_e_grid,
|
||||
const index_t batch_count,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CDEElementwiseOperation cde_element_op,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const Block2ETileMap block_2_etile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
const index_t num_blocks_per_batch =
|
||||
|
||||
@@ -39,26 +39,25 @@ template <typename GridwiseGemm,
|
||||
bool HasMainKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_gemm_gemm_xdl_cshuffle_v1(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
const FloatAB* __restrict__ p_b1_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const AccElementwiseOperation acc_element_op,
|
||||
const B1ElementwiseOperation b1_element_op,
|
||||
const CElementwiseOperation c_element_op,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
|
||||
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2CTileMap block_2_ctile_map,
|
||||
const index_t batch_count,
|
||||
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
|
||||
kernel_gemm_gemm_xdl_cshuffle_v1(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
const FloatAB* __restrict__ p_b1_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const AccElementwiseOperation acc_element_op,
|
||||
const B1ElementwiseOperation b1_element_op,
|
||||
const CElementwiseOperation c_element_op,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
|
||||
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2CTileMap block_2_ctile_map,
|
||||
const index_t batch_count,
|
||||
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
@@ -63,24 +63,24 @@ template <typename GridwiseGemm,
|
||||
bool HasMainKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_batched_gemm_xdl(const ABDataType* __restrict__ p_a_grid,
|
||||
const ABDataType* __restrict__ p_b_grid,
|
||||
DsPointer p_ds_grid,
|
||||
EDataType* __restrict__ p_e_grid,
|
||||
const index_t batch_count,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CDEElementwiseOperation cde_element_op,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
|
||||
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const Block2ETileMap block_2_etile_map)
|
||||
kernel_batched_gemm_xdl(const ABDataType* __restrict__ p_a_grid,
|
||||
const ABDataType* __restrict__ p_b_grid,
|
||||
DsPointer p_ds_grid,
|
||||
EDataType* __restrict__ p_e_grid,
|
||||
const index_t batch_count,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CDEElementwiseOperation cde_element_op,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
|
||||
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const Block2ETileMap block_2_etile_map)
|
||||
{
|
||||
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
|
||||
@@ -52,23 +52,23 @@ template <typename GridwiseGemm,
|
||||
bool HasDoubleTailKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_gemm_dl_multiple_d(
|
||||
const ABDataType* __restrict__ p_a_grid,
|
||||
const ABDataType* __restrict__ p_b_grid,
|
||||
DsPointer p_ds_grid,
|
||||
EDataType* __restrict__ p_e_grid,
|
||||
const index_t batch_count,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CDEElementwiseOperation cde_element_op,
|
||||
const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
|
||||
const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
|
||||
const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11,
|
||||
const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
kernel_gemm_dl_multiple_d(
|
||||
const ABDataType* __restrict__ p_a_grid,
|
||||
const ABDataType* __restrict__ p_b_grid,
|
||||
DsPointer p_ds_grid,
|
||||
EDataType* __restrict__ p_e_grid,
|
||||
const index_t batch_count,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CDEElementwiseOperation cde_element_op,
|
||||
const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
|
||||
const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
|
||||
const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11,
|
||||
const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
|
||||
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \
|
||||
|
||||
@@ -42,32 +42,32 @@ template <typename GridwiseGemm,
|
||||
bool HasMainKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_batched_gemm_gemm_xdl_cshuffle_v1(
|
||||
const A0B0B1DataType* __restrict__ p_a0_grid,
|
||||
const A0B0B1DataType* __restrict__ p_b0_grid,
|
||||
D0sPointer p_d0s_grid,
|
||||
const A0B0B1DataType* __restrict__ p_b1_grid,
|
||||
D1sPointer p_d1s_grid,
|
||||
E1DataType* __restrict__ p_e1_grid,
|
||||
const A0ElementwiseOperation a0_element_op,
|
||||
const B0ElementwiseOperation b0_element_op,
|
||||
const CDE0ElementwiseOperation cde0_element_op,
|
||||
const B1ElementwiseOperation b1_element_op,
|
||||
const CDE1ElementwiseOperation cde1_element_op,
|
||||
const A0GridDesc_AK0_M_AK1 a0_grid_desc_ak0_m_ak1,
|
||||
const B0GridDesc_BK0_N_BK1 b0_grid_desc_bk0_n_bk1,
|
||||
const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
|
||||
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
|
||||
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
|
||||
const D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
d1s_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e1_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2E1TileMap block_2_e1tile_map,
|
||||
const index_t batch_count,
|
||||
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
|
||||
kernel_batched_gemm_gemm_xdl_cshuffle_v1(
|
||||
const A0B0B1DataType* __restrict__ p_a0_grid,
|
||||
const A0B0B1DataType* __restrict__ p_b0_grid,
|
||||
D0sPointer p_d0s_grid,
|
||||
const A0B0B1DataType* __restrict__ p_b1_grid,
|
||||
D1sPointer p_d1s_grid,
|
||||
E1DataType* __restrict__ p_e1_grid,
|
||||
const A0ElementwiseOperation a0_element_op,
|
||||
const B0ElementwiseOperation b0_element_op,
|
||||
const CDE0ElementwiseOperation cde0_element_op,
|
||||
const B1ElementwiseOperation b1_element_op,
|
||||
const CDE1ElementwiseOperation cde1_element_op,
|
||||
const A0GridDesc_AK0_M_AK1 a0_grid_desc_ak0_m_ak1,
|
||||
const B0GridDesc_BK0_N_BK1 b0_grid_desc_bk0_n_bk1,
|
||||
const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
|
||||
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
|
||||
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
|
||||
const D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
d1s_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e1_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2E1TileMap block_2_e1tile_map,
|
||||
const index_t batch_count,
|
||||
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
@@ -829,10 +829,8 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
|
||||
is_same_v<tensor_layout::gemm::ColumnMajor, B0Layout> &&
|
||||
CheckDLayout<tensor_layout::gemm::RowMajor, D0sLayout, NumD0Tensor>() &&
|
||||
(is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ||
|
||||
is_same_v<tensor_layout::gemm::ColumnMajor,
|
||||
B1Layout>)&&CheckDLayout<tensor_layout::gemm::RowMajor,
|
||||
D1sLayout,
|
||||
NumD1Tensor>() &&
|
||||
is_same_v<tensor_layout::gemm::ColumnMajor, B1Layout>) &&
|
||||
CheckDLayout<tensor_layout::gemm::RowMajor, D1sLayout, NumD1Tensor>() &&
|
||||
is_same_v<tensor_layout::gemm::RowMajor, E1Layout>))
|
||||
{
|
||||
return false;
|
||||
|
||||
@@ -33,9 +33,9 @@ template <typename GridwiseGemm,
|
||||
TailNumber TailNum = TailNumber::Full>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
#endif
|
||||
kernel_batched_gemm_xdl_cshuffle_v3_multi_d(BatchedGemmArg karg)
|
||||
kernel_batched_gemm_xdl_cshuffle_v3_multi_d(BatchedGemmArg karg)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
@@ -79,9 +79,9 @@ template <typename GridwiseGemm,
|
||||
TailNumber TailNum = TailNumber::Full>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
#endif
|
||||
kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds(BatchedGemmArg karg)
|
||||
kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds(BatchedGemmArg karg)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
// Pass two lds pointer is the key to tell compiler that ds_read/write
|
||||
|
||||
@@ -39,26 +39,26 @@ template <typename GridwiseGemm,
|
||||
bool HasMainK0BlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_batched_gemm_reduce_xdl_cshuffle_v1(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
ReducePtrsGlobal p_reduces_grid,
|
||||
const index_t batch_count,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CElementwiseOperation c_element_op,
|
||||
const ReduceInElementwiseOperations reduce_in_element_ops,
|
||||
const ReduceAccElementwiseOperations reduce_out_element_ops,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock,
|
||||
const ComputeBasePrtOfBatch compute_base_ptr_of_batch_,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
kernel_batched_gemm_reduce_xdl_cshuffle_v1(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
ReducePtrsGlobal p_reduces_grid,
|
||||
const index_t batch_count,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CElementwiseOperation c_element_op,
|
||||
const ReduceInElementwiseOperations reduce_in_element_ops,
|
||||
const ReduceAccElementwiseOperations reduce_out_element_ops,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock,
|
||||
const ComputeBasePrtOfBatch compute_base_ptr_of_batch_,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
const index_t num_blocks_per_batch =
|
||||
|
||||
@@ -40,21 +40,21 @@ template <typename DeviceOp,
|
||||
bool HasMainKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_batched_gemm_softmax_gemm_wmma_cshuffle(const ADataType* __restrict__ p_a_grid,
|
||||
const B0DataType* __restrict__ p_b0_grid,
|
||||
const B1DataType* __restrict__ p_b1_grid,
|
||||
CDataType* __restrict__ p_c_grid,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t O,
|
||||
index_t G0,
|
||||
index_t G1,
|
||||
float alpha,
|
||||
bool input_permute,
|
||||
bool output_permute)
|
||||
kernel_batched_gemm_softmax_gemm_wmma_cshuffle(const ADataType* __restrict__ p_a_grid,
|
||||
const B0DataType* __restrict__ p_b0_grid,
|
||||
const B1DataType* __restrict__ p_b1_grid,
|
||||
CDataType* __restrict__ p_c_grid,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t O,
|
||||
index_t G0,
|
||||
index_t G1,
|
||||
float alpha,
|
||||
bool input_permute,
|
||||
bool output_permute)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
||||
|
||||
@@ -178,15 +178,15 @@ template <typename DeviceOp,
|
||||
bool HasMainKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_wmma_self_attention_forward(const QKVDataType* __restrict__ p_qkv_grid,
|
||||
ODataType* __restrict__ p_out_grid,
|
||||
index_t batch_size,
|
||||
index_t sequence_length,
|
||||
index_t head_count,
|
||||
index_t head_size,
|
||||
float alpha)
|
||||
kernel_wmma_self_attention_forward(const QKVDataType* __restrict__ p_qkv_grid,
|
||||
ODataType* __restrict__ p_out_grid,
|
||||
index_t batch_size,
|
||||
index_t sequence_length,
|
||||
index_t head_count,
|
||||
index_t head_size,
|
||||
float alpha)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
||||
|
||||
@@ -310,17 +310,17 @@ template <typename DeviceOp,
|
||||
bool HasMainKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_wmma_cross_attention_forward(const QDataType* __restrict__ p_q_grid,
|
||||
const KVDataType* __restrict__ p_kv_grid,
|
||||
ODataType* __restrict__ p_out_grid,
|
||||
index_t batch_size,
|
||||
index_t q_sequence_length,
|
||||
index_t kv_sequence_length,
|
||||
index_t head_count,
|
||||
index_t head_size,
|
||||
float alpha)
|
||||
kernel_wmma_cross_attention_forward(const QDataType* __restrict__ p_q_grid,
|
||||
const KVDataType* __restrict__ p_kv_grid,
|
||||
ODataType* __restrict__ p_out_grid,
|
||||
index_t batch_size,
|
||||
index_t q_sequence_length,
|
||||
index_t kv_sequence_length,
|
||||
index_t head_count,
|
||||
index_t head_size,
|
||||
float alpha)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
||||
|
||||
|
||||
@@ -43,30 +43,30 @@ template <typename GridwiseGemm,
|
||||
bool HasMainKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
const FloatAB* __restrict__ p_b1_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
D0sPointer p_d0s_grid,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const C0DEElementwiseOperation c0de_element_op,
|
||||
const B1ElementwiseOperation b1_element_op,
|
||||
const C1DEElementwiseOperation c1de_element_op,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
|
||||
const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c1_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
|
||||
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
|
||||
const Block2CTileMap block_2_ctile_map,
|
||||
const index_t batch_count,
|
||||
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
|
||||
const C0MatrixMask c0_matrix_mask)
|
||||
kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
const FloatAB* __restrict__ p_b1_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
D0sPointer p_d0s_grid,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const C0DEElementwiseOperation c0de_element_op,
|
||||
const B1ElementwiseOperation b1_element_op,
|
||||
const C1DEElementwiseOperation c1de_element_op,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
|
||||
const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c1_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
|
||||
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
|
||||
const Block2CTileMap block_2_ctile_map,
|
||||
const index_t batch_count,
|
||||
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
|
||||
const C0MatrixMask c0_matrix_mask)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
@@ -42,27 +42,27 @@ template <typename GridwiseGemm,
|
||||
bool HasMainKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
const FloatAB* __restrict__ p_b1_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const AccElementwiseOperation acc_element_op,
|
||||
const B1ElementwiseOperation b1_element_op,
|
||||
const CElementwiseOperation c_element_op,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
|
||||
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2CTileMap block_2_ctile_map,
|
||||
const index_t batch_count,
|
||||
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
|
||||
const C0MatrixMask c0_matrix_mask)
|
||||
kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
const FloatAB* __restrict__ p_b1_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const AccElementwiseOperation acc_element_op,
|
||||
const B1ElementwiseOperation b1_element_op,
|
||||
const CElementwiseOperation c_element_op,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
|
||||
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2CTileMap block_2_ctile_map,
|
||||
const index_t batch_count,
|
||||
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
|
||||
const C0MatrixMask c0_matrix_mask)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
@@ -29,14 +29,13 @@ template <typename GridwiseGemm,
|
||||
TailNumber TailNum = TailNumber::Full>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
#endif
|
||||
kernel_batched_gemm_wmma_cshuffle_v3(
|
||||
typename GridwiseGemm::Argument
|
||||
karg, // This works for now but it actually receives a
|
||||
// DeviceBatchedGemm_Wmma_CShuffleV3::Argument
|
||||
// argument through implicit conversion to base class!
|
||||
const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
|
||||
kernel_batched_gemm_wmma_cshuffle_v3(
|
||||
typename GridwiseGemm::Argument karg, // This works for now but it actually receives a
|
||||
// DeviceBatchedGemm_Wmma_CShuffleV3::Argument
|
||||
// argument through implicit conversion to base class!
|
||||
const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
||||
#if defined(__gfx11__)
|
||||
|
||||
@@ -48,9 +48,9 @@ namespace device {
|
||||
template <typename DeviceOp, typename GridwiseGemm, bool HasMainKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_batched_gemm_xdlops_v2r3(const typename DeviceOp::Argument karg)
|
||||
kernel_batched_gemm_xdlops_v2r3(const typename DeviceOp::Argument karg)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
const index_t num_blocks_per_batch =
|
||||
|
||||
@@ -33,9 +33,9 @@ template <typename GridwiseGemm,
|
||||
TailNumber TailNum = TailNumber::Full>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
#endif
|
||||
kernel_batched_gemm_b_scale_xdl_cshuffle_v3(BatchedGemmArg karg)
|
||||
kernel_batched_gemm_b_scale_xdl_cshuffle_v3(BatchedGemmArg karg)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
@@ -71,9 +71,9 @@ template <typename GridwiseGemm,
|
||||
TailNumber TailNum = TailNumber::Full>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
#endif
|
||||
kernel_batched_gemm_b_scale_xdl_cshuffle_v3_2lds(BatchedGemmArg karg)
|
||||
kernel_batched_gemm_b_scale_xdl_cshuffle_v3_2lds(BatchedGemmArg karg)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
// Pass two lds pointer is the key to tell compiler that ds_read/write
|
||||
|
||||
@@ -610,8 +610,8 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
|
||||
if(!parg)
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "Provided argument pointer is not of an Argument class!"
|
||||
<< " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
|
||||
err << "Provided argument pointer is not of an Argument class!" << " In " << __FILE__
|
||||
<< ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
|
||||
|
||||
@@ -467,12 +467,12 @@ struct DeviceColumnToImageImpl
|
||||
|
||||
float elapsed_time = 0.f;
|
||||
const auto kernel = kernel_tensor_rearrange<InputGridDesc,
|
||||
InputDataType,
|
||||
OutputGridDesc,
|
||||
OutputDataType,
|
||||
Block2ETileMap,
|
||||
ComputePtrOffsetOfStridedBatch<>,
|
||||
GridwiseTensorRearrangeKernel>;
|
||||
InputDataType,
|
||||
OutputGridDesc,
|
||||
OutputDataType,
|
||||
Block2ETileMap,
|
||||
ComputePtrOffsetOfStridedBatch<>,
|
||||
GridwiseTensorRearrangeKernel>;
|
||||
|
||||
// Execute each set of independent filters
|
||||
for(std::size_t i = 0; i < arg.in_grid_desc_m_k_container_.size(); i++)
|
||||
|
||||
@@ -37,23 +37,23 @@ template <typename GridwiseGemm,
|
||||
bool HasMainKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_contraction_multiple_abd_xdl_cshuffle(
|
||||
AsPointer p_as_grid,
|
||||
BsPointer p_bs_grid,
|
||||
DsPointer p_ds_grid,
|
||||
EDataType* __restrict__ p_e_grid,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CDEElementwiseOperation cde_element_op,
|
||||
const AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1,
|
||||
const BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1,
|
||||
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2ETileMap block_2_etile_map)
|
||||
kernel_contraction_multiple_abd_xdl_cshuffle(
|
||||
AsPointer p_as_grid,
|
||||
BsPointer p_bs_grid,
|
||||
DsPointer p_ds_grid,
|
||||
EDataType* __restrict__ p_e_grid,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CDEElementwiseOperation cde_element_op,
|
||||
const AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1,
|
||||
const BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1,
|
||||
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2ETileMap block_2_etile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
@@ -35,23 +35,23 @@ template <typename GridwiseGemm,
|
||||
bool HasMainKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_contraction_multiple_d_xdl_cshuffle(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatDsPointer p_ds_grid,
|
||||
FloatE* __restrict__ p_e_grid,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CDEElementwiseOperation cde_element_op,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2ETileMap block_2_etile_map)
|
||||
kernel_contraction_multiple_d_xdl_cshuffle(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatDsPointer p_ds_grid,
|
||||
FloatE* __restrict__ p_e_grid,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CDEElementwiseOperation cde_element_op,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2ETileMap block_2_etile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
@@ -35,17 +35,15 @@ auto CalculateMaxRead(const std::vector<index_t>& lengths, const std::vector<ind
|
||||
if(lengths.size() != NumDim1 + NumDim2)
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "Incorrect number of lengths in "
|
||||
<< "device_contraction_utils.hpp"
|
||||
<< ":" << __LINE__ << ", in function: " << __func__;
|
||||
err << "Incorrect number of lengths in " << "device_contraction_utils.hpp" << ":"
|
||||
<< __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
if(strides.size() != NumDim1 + NumDim2)
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "Incorrect number of strides in "
|
||||
<< "device_contraction_utils.hpp"
|
||||
<< ":" << __LINE__ << ", in function: " << __func__;
|
||||
err << "Incorrect number of strides in " << "device_contraction_utils.hpp" << ":"
|
||||
<< __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
|
||||
|
||||
@@ -648,9 +648,8 @@ struct
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << DeviceOp{}.GetTypeString() << std::endl;
|
||||
std::cout << "N " << arg.Conv_N_ << ", "
|
||||
<< "K " << arg.Conv_K_ << ", "
|
||||
<< "C " << arg.Conv_C_ << ", " << std::endl;
|
||||
std::cout << "N " << arg.Conv_N_ << ", " << "K " << arg.Conv_K_ << ", " << "C "
|
||||
<< arg.Conv_C_ << ", " << std::endl;
|
||||
std::cout << "Y X " << arg.filter_spatial_lengths_[0] << ", "
|
||||
<< arg.filter_spatial_lengths_[1] << ", " << std::endl;
|
||||
std::cout << "Hi Wi " << arg.input_spatial_lengths_[0] << ", "
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user