Added Multi_ABD support into Gemm and GroupedGemmFixedNK (#978)

* added an example grouped_gemm_multi_abd

* fixed ci

* add setElementwiseOp

* changed API

* clean code: add multiA into example

* fixed v7r2 copy

* add transpose

* clean

* fixed vector_load check

* Update example/15_grouped_gemm/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>

* Update example/15_grouped_gemm/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>

* Update example/15_grouped_gemm/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>

* Update include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>

* Update include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>

* Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>

* Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>

* Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>

* Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>

* Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>

* Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>

* Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>

* Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>

* add reduce

* testing

* add example_b16_i8

* refactor example

* clean

* add mpading

* disable reduce for kbatch = 1

* seperate reduce device op

* add reduce op

* add guard for workspace_size

* add instances

* format

* fixed

* add client example

* add a colmajor

* add instances

* Update cmake-ck-dev.sh

* Update profile_gemm_splitk.cpp

* Update gridwise_gemm_xdlops_v2r4r2.hpp

* format

* Update profile_gemm_splitk.cpp

* fixed

* fixed

* adjust test

* adjust precision loss

* adjust test

* fixed

* add bf16_i8 scale bias

* fixed scale

* fixed scale elementwise_op

* revert contraction deviceop changes

* fixed

* Add AddFastGelu

* Revert "Merge branch 'jizhan/gemm_splitk_reduce' into grouped_gemm_multi_abd_fixed_nk_example"

This reverts commit 3b5d001efd, reversing
changes made to 943199a991.

* add Scales into elementwise

* add gemm_multi_abd client example

* add client examples

* add rcr and crr

* add grouped gemm client example

* add grouped gemm client example

* add instance for rcr crr

* format

* fixed

* fixed cmake

* fixed

* fixed client_example

* format

* fixed contraction isSupport

* Update include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp

Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>

* Update device_reduce_threadwise.hpp

* clean

* Fixes

* Fix example

---------

Co-authored-by: Jing Zhang <jizha@amd.com>
Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>
This commit is contained in:
zjing14
2024-04-15 21:09:45 -05:00
committed by GitHub
parent db376dd8a4
commit 12865fbf28
45 changed files with 6345 additions and 199 deletions

View File

@@ -0,0 +1,13 @@
if(GPU_TARGETS MATCHES "gfx9" AND ((DTYPES MATCHES "int8" AND DTYPES MATCHES "bf16") OR NOT DEFINED DTYPES))
add_executable(client_gemm_bias_fastgelu_bf16_i8_bf16 gemm_bias_fastgelu_xdl_bf16_i8.cpp)
target_link_libraries(client_gemm_bias_fastgelu_bf16_i8_bf16 PRIVATE composable_kernel::device_gemm_operations)
add_executable(client_gemm_bias_bf16_i8_bf16 gemm_bias_xdl_bf16_i8.cpp)
target_link_libraries(client_gemm_bias_bf16_i8_bf16 PRIVATE composable_kernel::device_gemm_operations)
add_executable(client_gemm_gelu_bf16_i8_bf16 gemm_xdl_gelu_bf16_i8.cpp)
target_link_libraries(client_gemm_gelu_bf16_i8_bf16 PRIVATE composable_kernel::device_gemm_operations)
add_executable(client_gemm_bf16_i8_bf16 gemm_xdl_bf16_i8.cpp)
target_link_libraries(client_gemm_bf16_i8_bf16 PRIVATE composable_kernel::device_gemm_operations)
endif()

View File

@@ -0,0 +1,262 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <iomanip>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm_multi_abd.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using BF16 = ck::bhalf_t;
using I8 = int8_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = BF16;
using AsDataType = ck::Tuple<A0DataType>;
using B0DataType = I8;
using B1DataType = BF16;
using BsDataType = ck::Tuple<B0DataType, B1DataType>;
using AccDataType = F32;
using CShuffleDataType = BF16;
using D0DataType = BF16;
using DsDataType = ck::Tuple<D0DataType>;
using EDataType = BF16;
using A0Layout = Row;
using AsLayout = ck::Tuple<A0Layout>;
using B0Layout = Col;
using B1Layout = B0Layout;
using BsLayout = ck::Tuple<B0Layout, B1Layout>;
using D0Layout = Row;
using DsLayout = ck::Tuple<D0Layout>;
using ELayout = Row;
using Scales = ck::tensor_operation::element_wise::Scales;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using AElementOp = PassThrough;
using BElementOp = Scales;
using CDEElementOp = AddFastGelu;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
struct SimpleDeviceMem
{
SimpleDeviceMem() = delete;
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
{
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
}
void* GetDeviceBuffer() { return p_mem_; }
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
void* p_mem_;
};
// clang-format on
int main(int argc, char* argv[])
{
// GEMM shape
ck::index_t M = 64;
ck::index_t N = 1024;
ck::index_t K = 512;
ck::index_t StrideA = K;
ck::index_t StrideB = K;
ck::index_t StrideD = N;
ck::index_t StrideE = N;
if(argc == 1)
{
// use default case
}
else if(argc == 8)
{
M = std::stoi(argv[1]);
N = std::stoi(argv[2]);
K = std::stoi(argv[3]);
StrideA = std::stoi(argv[4]);
StrideB = std::stoi(argv[5]);
StrideD = std::stoi(argv[6]);
StrideE = std::stoi(argv[7]);
}
else
{
printf("arg1 to 7: M, N, K, StrideA, StrideB, StrideD, StrideE\n");
exit(0);
}
auto f_matrix_space_size =
[](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
using Layout = decltype(layout);
if constexpr(std::is_same<Layout, Row>::value)
{
return (nRow - 1) * stride + nCol;
}
else
{
return (nCol - 1) * stride + nRow;
}
};
SimpleDeviceMem a0_device_buf(sizeof(A0DataType) *
f_matrix_space_size(M, K, StrideA, A0Layout{}));
SimpleDeviceMem b0_device_buf(sizeof(B0DataType) *
f_matrix_space_size(K, N, StrideB, B0Layout{}));
SimpleDeviceMem b1_device_buf(sizeof(B1DataType) * f_matrix_space_size(K, N, 0, B1Layout{}));
SimpleDeviceMem d0_device_buf(sizeof(D0DataType) *
f_matrix_space_size(M, N, StrideD, ELayout{}));
SimpleDeviceMem e_device_buf(sizeof(EDataType) * f_matrix_space_size(M, N, StrideE, ELayout{}));
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
constexpr ck::index_t NumATensor = 1;
constexpr ck::index_t NumBTensor = 2;
constexpr ck::index_t NumDTensor = 1;
using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleABD<AsLayout,
BsLayout,
DsLayout,
Row,
AsDataType,
BsDataType,
DsDataType,
BF16,
AElementOp,
BElementOp,
CDEElementOp>;
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
std::string best_op_name;
bool found = false;
int best_op_id = -1;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
// profile device operation instances
std::cout << "Run all instances and do timing" << std::endl;
for(int i = 0; i < op_ptrs.size(); ++i)
{
auto& op_ptr = op_ptrs[i];
auto argument_ptr = op_ptr->MakeArgumentPointer(
std::array<const void*, NumATensor>{a0_device_buf.GetDeviceBuffer()},
std::array<const void*, NumBTensor>{b0_device_buf.GetDeviceBuffer(),
b1_device_buf.GetDeviceBuffer()},
std::array<const void*, NumDTensor>{d0_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
M,
N,
K,
std::array<ck::index_t, NumATensor>{StrideA},
std::array<ck::index_t, NumBTensor>{StrideB, 0},
std::array<ck::index_t, NumDTensor>{StrideD},
StrideE,
a_element_op,
b_element_op,
cde_element_op);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
std::string op_name = op_ptr->GetTypeString();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
if(tflops > best_tflops)
{
found = true;
best_op_id = i;
best_op_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
}
else
{
std::cout << op_name << " does not support this problem" << std::endl;
}
}
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
// run the best intance
if(found)
{
auto& op_ptr = op_ptrs[best_op_id];
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
<< std::endl;
auto argument_ptr = op_ptr->MakeArgumentPointer(
std::array<const void*, NumATensor>{a0_device_buf.GetDeviceBuffer()},
std::array<const void*, NumBTensor>{b0_device_buf.GetDeviceBuffer(),
b1_device_buf.GetDeviceBuffer()},
std::array<const void*, NumDTensor>{d0_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
M,
N,
K,
std::array<ck::index_t, NumATensor>{StrideA},
std::array<ck::index_t, NumBTensor>{StrideB, 0},
std::array<ck::index_t, NumDTensor>{StrideD},
StrideE,
a_element_op,
b_element_op,
cde_element_op);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
}
std::cout << "Done" << std::endl;
}
return 0;
}

View File

@@ -0,0 +1,262 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <iomanip>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm_multi_abd.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using BF16 = ck::bhalf_t;
using I8 = int8_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = BF16;
using AsDataType = ck::Tuple<A0DataType>;
using B0DataType = I8;
using B1DataType = BF16;
using BsDataType = ck::Tuple<B0DataType, B1DataType>;
using AccDataType = F32;
using CShuffleDataType = BF16;
using D0DataType = BF16;
using DsDataType = ck::Tuple<D0DataType>;
using EDataType = BF16;
using A0Layout = Col;
using AsLayout = ck::Tuple<A0Layout>;
using B0Layout = Row;
using B1Layout = B0Layout;
using BsLayout = ck::Tuple<B0Layout, B1Layout>;
using D0Layout = Row;
using DsLayout = ck::Tuple<D0Layout>;
using ELayout = Row;
using Scales = ck::tensor_operation::element_wise::Scales;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Add = ck::tensor_operation::element_wise::Add;
using AElementOp = PassThrough;
using BElementOp = Scales;
using CDEElementOp = Add;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
struct SimpleDeviceMem
{
SimpleDeviceMem() = delete;
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
{
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
}
void* GetDeviceBuffer() { return p_mem_; }
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
void* p_mem_;
};
// clang-format on
int main(int argc, char* argv[])
{
// GEMM shape
ck::index_t M = 64;
ck::index_t N = 1024;
ck::index_t K = 512;
ck::index_t StrideA = M;
ck::index_t StrideB = N;
ck::index_t StrideD = N;
ck::index_t StrideE = N;
if(argc == 1)
{
// use default case
}
else if(argc == 8)
{
M = std::stoi(argv[1]);
N = std::stoi(argv[2]);
K = std::stoi(argv[3]);
StrideA = std::stoi(argv[4]);
StrideB = std::stoi(argv[5]);
StrideD = std::stoi(argv[6]);
StrideE = std::stoi(argv[7]);
}
else
{
printf("arg1 to 7: M, N, K, StrideA, StrideB, StrideD, StrideE\n");
exit(0);
}
auto f_matrix_space_size =
[](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
using Layout = decltype(layout);
if constexpr(std::is_same<Layout, Row>::value)
{
return (nRow - 1) * stride + nCol;
}
else
{
return (nCol - 1) * stride + nRow;
}
};
SimpleDeviceMem a0_device_buf(sizeof(A0DataType) *
f_matrix_space_size(M, K, StrideA, A0Layout{}));
SimpleDeviceMem b0_device_buf(sizeof(B0DataType) *
f_matrix_space_size(K, N, StrideB, B0Layout{}));
SimpleDeviceMem b1_device_buf(sizeof(B1DataType) * f_matrix_space_size(K, N, 0, B1Layout{}));
SimpleDeviceMem d0_device_buf(sizeof(D0DataType) *
f_matrix_space_size(M, N, StrideD, ELayout{}));
SimpleDeviceMem e_device_buf(sizeof(EDataType) * f_matrix_space_size(M, N, StrideE, ELayout{}));
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
constexpr ck::index_t NumATensor = 1;
constexpr ck::index_t NumBTensor = 2;
constexpr ck::index_t NumDTensor = 1;
using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleABD<AsLayout,
BsLayout,
DsLayout,
Row,
AsDataType,
BsDataType,
DsDataType,
BF16,
AElementOp,
BElementOp,
CDEElementOp>;
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
std::string best_op_name;
bool found = false;
int best_op_id = -1;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
// profile device operation instances
std::cout << "Run all instances and do timing" << std::endl;
for(int i = 0; i < op_ptrs.size(); ++i)
{
auto& op_ptr = op_ptrs[i];
auto argument_ptr = op_ptr->MakeArgumentPointer(
std::array<const void*, NumATensor>{a0_device_buf.GetDeviceBuffer()},
std::array<const void*, NumBTensor>{b0_device_buf.GetDeviceBuffer(),
b1_device_buf.GetDeviceBuffer()},
std::array<const void*, NumDTensor>{d0_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
M,
N,
K,
std::array<ck::index_t, NumATensor>{StrideA},
std::array<ck::index_t, NumBTensor>{StrideB, 0},
std::array<ck::index_t, NumDTensor>{StrideD},
StrideE,
a_element_op,
b_element_op,
cde_element_op);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
std::string op_name = op_ptr->GetTypeString();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
if(tflops > best_tflops)
{
found = true;
best_op_id = i;
best_op_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
}
else
{
std::cout << op_name << " does not support this problem" << std::endl;
}
}
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
// run the best intance
if(found)
{
auto& op_ptr = op_ptrs[best_op_id];
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
<< std::endl;
auto argument_ptr = op_ptr->MakeArgumentPointer(
std::array<const void*, NumATensor>{a0_device_buf.GetDeviceBuffer()},
std::array<const void*, NumBTensor>{b0_device_buf.GetDeviceBuffer(),
b1_device_buf.GetDeviceBuffer()},
std::array<const void*, NumDTensor>{d0_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
M,
N,
K,
std::array<ck::index_t, NumATensor>{StrideA},
std::array<ck::index_t, NumBTensor>{StrideB, 0},
std::array<ck::index_t, NumDTensor>{StrideD},
StrideE,
a_element_op,
b_element_op,
cde_element_op);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
}
std::cout << "Done" << std::endl;
}
return 0;
}

View File

@@ -0,0 +1,257 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <iomanip>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm_multi_abd.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using BF16 = ck::bhalf_t;
using I8 = int8_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = BF16;
using AsDataType = ck::Tuple<A0DataType>;
using B0DataType = I8;
using B1DataType = BF16;
using BsDataType = ck::Tuple<B0DataType, B1DataType>;
using AccDataType = F32;
using CShuffleDataType = BF16;
using DsDataType = ck::Tuple<>;
using EDataType = BF16;
using A0Layout = Row;
using AsLayout = ck::Tuple<A0Layout>;
using B0Layout = Col;
using B1Layout = B0Layout;
using BsLayout = ck::Tuple<B0Layout, B1Layout>;
using D0Layout = Row;
using DsLayout = ck::Tuple<>;
using ELayout = Row;
using Scales = ck::tensor_operation::element_wise::Scales;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Add = ck::tensor_operation::element_wise::Add;
using AElementOp = PassThrough;
using BElementOp = Scales;
using CDEElementOp = PassThrough;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
struct SimpleDeviceMem
{
SimpleDeviceMem() = delete;
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
{
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
}
void* GetDeviceBuffer() { return p_mem_; }
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
void* p_mem_;
};
// clang-format on
int main(int argc, char* argv[])
{
// GEMM shape
ck::index_t M = 64;
ck::index_t N = 1024;
ck::index_t K = 512;
ck::index_t StrideA = K;
ck::index_t StrideB = N;
ck::index_t StrideE = N;
if(argc == 1)
{
// use default case
}
else if(argc == 7)
{
M = std::stoi(argv[1]);
N = std::stoi(argv[2]);
K = std::stoi(argv[3]);
StrideA = std::stoi(argv[4]);
StrideB = std::stoi(argv[5]);
StrideE = std::stoi(argv[6]);
}
else
{
printf("arg1 to 7: M, N, K, StrideA, StrideB, StrideE\n");
exit(0);
}
auto f_matrix_space_size =
[](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
using Layout = decltype(layout);
if constexpr(std::is_same<Layout, Row>::value)
{
return (nRow - 1) * stride + nCol;
}
else
{
return (nCol - 1) * stride + nRow;
}
};
SimpleDeviceMem a0_device_buf(sizeof(A0DataType) *
f_matrix_space_size(M, K, StrideA, A0Layout{}));
SimpleDeviceMem b0_device_buf(sizeof(B0DataType) *
f_matrix_space_size(K, N, StrideB, B0Layout{}));
SimpleDeviceMem b1_device_buf(sizeof(B1DataType) * f_matrix_space_size(K, N, 0, B1Layout{}));
SimpleDeviceMem e_device_buf(sizeof(EDataType) * f_matrix_space_size(M, N, StrideE, ELayout{}));
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
constexpr ck::index_t NumATensor = 1;
constexpr ck::index_t NumBTensor = 2;
constexpr ck::index_t NumDTensor = 0;
using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleABD<AsLayout,
BsLayout,
DsLayout,
Row,
AsDataType,
BsDataType,
DsDataType,
BF16,
AElementOp,
BElementOp,
CDEElementOp>;
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
std::string best_op_name;
bool found = false;
int best_op_id = -1;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
// profile device operation instances
std::cout << "Run all instances and do timing" << std::endl;
for(int i = 0; i < op_ptrs.size(); ++i)
{
auto& op_ptr = op_ptrs[i];
auto argument_ptr = op_ptr->MakeArgumentPointer(
std::array<const void*, NumATensor>{a0_device_buf.GetDeviceBuffer()},
std::array<const void*, NumBTensor>{b0_device_buf.GetDeviceBuffer(),
b1_device_buf.GetDeviceBuffer()},
std::array<const void*, NumDTensor>{},
e_device_buf.GetDeviceBuffer(),
M,
N,
K,
std::array<ck::index_t, NumATensor>{StrideA},
std::array<ck::index_t, NumBTensor>{StrideB, 0},
std::array<ck::index_t, NumDTensor>{},
StrideE,
a_element_op,
b_element_op,
cde_element_op);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
std::string op_name = op_ptr->GetTypeString();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
if(tflops > best_tflops)
{
found = true;
best_op_id = i;
best_op_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
}
else
{
std::cout << op_name << " does not support this problem" << std::endl;
}
}
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
// run the best intance
if(found)
{
auto& op_ptr = op_ptrs[best_op_id];
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
<< std::endl;
auto argument_ptr = op_ptr->MakeArgumentPointer(
std::array<const void*, NumATensor>{a0_device_buf.GetDeviceBuffer()},
std::array<const void*, NumBTensor>{b0_device_buf.GetDeviceBuffer(),
b1_device_buf.GetDeviceBuffer()},
std::array<const void*, NumDTensor>{},
e_device_buf.GetDeviceBuffer(),
M,
N,
K,
std::array<ck::index_t, NumATensor>{StrideA},
std::array<ck::index_t, NumBTensor>{StrideB, 0},
std::array<ck::index_t, NumDTensor>{},
StrideE,
a_element_op,
b_element_op,
cde_element_op);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
}
std::cout << "Done" << std::endl;
}
return 0;
}

View File

@@ -0,0 +1,261 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <iomanip>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm_multi_abd.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using BF16 = ck::bhalf_t;
using I8 = int8_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = BF16;
using AsDataType = ck::Tuple<A0DataType>;
using B0DataType = I8;
using B1DataType = BF16;
using BsDataType = ck::Tuple<B0DataType, B1DataType>;
using AccDataType = F32;
using CShuffleDataType = BF16;
using DsDataType = ck::Tuple<>;
using EDataType = BF16;
using A0Layout = Row;
using AsLayout = ck::Tuple<A0Layout>;
using B0Layout = Col;
using B1Layout = B0Layout;
using BsLayout = ck::Tuple<B0Layout, B1Layout>;
using D0Layout = Row;
using DsLayout = ck::Tuple<>;
using ELayout = Row;
using Scales = ck::tensor_operation::element_wise::Scales;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using AElementOp = PassThrough;
using BElementOp = Scales;
using CDEElementOp = FastGelu;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
struct SimpleDeviceMem
{
SimpleDeviceMem() = delete;
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
{
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
}
void* GetDeviceBuffer() { return p_mem_; }
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
void* p_mem_;
};
// clang-format on
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 64;
ck::index_t N = 1024;
ck::index_t K = 512;
ck::index_t StrideA = K;
ck::index_t StrideB = N;
ck::index_t StrideE = N;
if(argc == 1)
{
// use default case
}
else if(argc == 7)
{
M = std::stoi(argv[1]);
N = std::stoi(argv[2]);
K = std::stoi(argv[3]);
StrideA = std::stoi(argv[4]);
StrideB = std::stoi(argv[5]);
StrideE = std::stoi(argv[6]);
}
else
{
printf("arg1 to 7: M, N, K, StrideA, StrideB, StrideE\n");
exit(0);
}
auto f_matrix_space_size =
[](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
using Layout = decltype(layout);
if constexpr(std::is_same<Layout, Row>::value)
{
return (nRow - 1) * stride + nCol;
}
else
{
return (nCol - 1) * stride + nRow;
}
};
SimpleDeviceMem a0_device_buf(sizeof(A0DataType) *
f_matrix_space_size(M, K, StrideA, A0Layout{}));
SimpleDeviceMem b0_device_buf(sizeof(B0DataType) *
f_matrix_space_size(K, N, StrideB, B0Layout{}));
SimpleDeviceMem b1_device_buf(sizeof(B1DataType) * f_matrix_space_size(K, N, 0, B1Layout{}));
SimpleDeviceMem e_device_buf(sizeof(EDataType) * f_matrix_space_size(M, N, StrideE, ELayout{}));
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
constexpr ck::index_t NumATensor = 1;
constexpr ck::index_t NumBTensor = 2;
constexpr ck::index_t NumDTensor = 0;
using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleABD<AsLayout,
BsLayout,
DsLayout,
Row,
AsDataType,
BsDataType,
DsDataType,
BF16,
AElementOp,
BElementOp,
CDEElementOp>;
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
std::string best_op_name;
bool found = false;
int best_op_id = -1;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
// profile device operation instances
std::cout << "Run all instances and do timing" << std::endl;
for(int i = 0; i < op_ptrs.size(); ++i)
{
auto& op_ptr = op_ptrs[i];
auto argument_ptr = op_ptr->MakeArgumentPointer(
std::array<const void*, NumATensor>{a0_device_buf.GetDeviceBuffer()},
std::array<const void*, NumBTensor>{b0_device_buf.GetDeviceBuffer(),
b1_device_buf.GetDeviceBuffer()},
std::array<const void*, NumDTensor>{},
e_device_buf.GetDeviceBuffer(),
M,
N,
K,
std::array<ck::index_t, NumATensor>{StrideA},
std::array<ck::index_t, NumBTensor>{StrideB, 0},
std::array<ck::index_t, NumDTensor>{},
StrideE,
a_element_op,
b_element_op,
cde_element_op);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
std::string op_name = op_ptr->GetTypeString();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
if(tflops > best_tflops)
{
found = true;
best_op_id = i;
best_op_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
}
else
{
std::cout << op_name << " does not support this problem" << std::endl;
}
}
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
// run the best intance
if(found)
{
auto& op_ptr = op_ptrs[best_op_id];
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
<< std::endl;
auto argument_ptr = op_ptr->MakeArgumentPointer(
std::array<const void*, NumATensor>{a0_device_buf.GetDeviceBuffer()},
std::array<const void*, NumBTensor>{b0_device_buf.GetDeviceBuffer(),
b1_device_buf.GetDeviceBuffer()},
std::array<const void*, NumDTensor>{},
e_device_buf.GetDeviceBuffer(),
M,
N,
K,
std::array<ck::index_t, NumATensor>{StrideA},
std::array<ck::index_t, NumBTensor>{StrideB, 0},
std::array<ck::index_t, NumDTensor>{},
StrideE,
a_element_op,
b_element_op,
cde_element_op);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
}
std::cout << "Done" << std::endl;
}
return 0;
}

View File

@@ -0,0 +1,7 @@
if(GPU_TARGETS MATCHES "gfx9" AND ((DTYPES MATCHES "int8" AND DTYPES MATCHES "bf16") OR NOT DEFINED DTYPES))
add_executable(client_grouped_gemm_bias_fastgelu_bf16_i8_bf16 grouped_gemm_bias_fastgelu_xdl_bf16_i8.cpp)
target_link_libraries(client_grouped_gemm_bias_fastgelu_bf16_i8_bf16 PRIVATE composable_kernel::device_gemm_operations)
add_executable(client_grouped_gemm_fastgelu_bf16_i8_bf16 grouped_gemm_fastgelu_xdl_bf16_i8.cpp)
target_link_libraries(client_grouped_gemm_fastgelu_bf16_i8_bf16 PRIVATE composable_kernel::device_gemm_operations)
endif()

View File

@@ -0,0 +1,286 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <iomanip>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using BF16 = ck::bhalf_t;
using I8 = int8_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = BF16;
using AsDataType = ck::Tuple<A0DataType>;
using B0DataType = I8;
using B1DataType = BF16;
using BsDataType = ck::Tuple<B0DataType, B1DataType>;
using AccDataType = F32;
using CShuffleDataType = BF16;
using D0DataType = BF16;
using DsDataType = ck::Tuple<D0DataType>;
using EDataType = BF16;
using A0Layout = Row;
using AsLayout = ck::Tuple<A0Layout>;
using B0Layout = Col;
using B1Layout = B0Layout;
using BsLayout = ck::Tuple<B0Layout, B1Layout>;
using D0Layout = Row;
using DsLayout = ck::Tuple<D0Layout>;
using ELayout = Row;
using Scales = ck::tensor_operation::element_wise::Scales;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using AElementOp = PassThrough;
using BElementOp = Scales;
using CDEElementOp = AddFastGelu;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
struct SimpleDeviceMem
{
SimpleDeviceMem() = delete;
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
{
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
}
void* GetDeviceBuffer() { return p_mem_; }
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
void* p_mem_;
};
struct ProblemSize final
{
std::vector<ck::index_t> Ms;
std::vector<ck::index_t> Ns;
std::vector<ck::index_t> Ks;
std::vector<ck::index_t> stride_As;
std::vector<ck::index_t> stride_Bs;
std::vector<ck::index_t> stride_Cs;
ck::index_t group_count;
};
struct ExecutionConfig final
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
int k_batch = 1;
};
bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
{
auto group_count = problem_size.group_count;
// GEMM shape
std::vector<ck::tensor_operation::device::GemmMultiABDDesc> gemm_descs;
gemm_descs.reserve(group_count);
int sum_of_m = 0;
using DeviceMemPtr = std::unique_ptr<SimpleDeviceMem>;
std::vector<DeviceMemPtr> a0_tensors_device, b0_tensors_device, b1_tensors_device,
d0_tensors_device, c_tensors_device;
a0_tensors_device.reserve(group_count);
b0_tensors_device.reserve(group_count);
b1_tensors_device.reserve(group_count);
d0_tensors_device.reserve(group_count);
c_tensors_device.reserve(group_count);
std::size_t flop = 0, num_btype = 0;
for(int i = 0; i < group_count; i++)
{
sum_of_m += problem_size.Ms[i];
}
constexpr ck::index_t NumATensor = 1;
constexpr ck::index_t NumBTensor = 2;
constexpr ck::index_t NumDTensor = 1;
using GroupedGemmKernelArgument = ck::tensor_operation::device::
GroupedGemmMultiABDKernelArgument<NumATensor, NumBTensor, NumDTensor>;
std::vector<GroupedGemmKernelArgument> grouped_gemm_kernel_args_;
grouped_gemm_kernel_args_.reserve(group_count);
for(int i = 0; i < group_count; i++)
{
a0_tensors_device.emplace_back(
std::make_unique<SimpleDeviceMem>(sizeof(A0DataType) * sum_of_m * problem_size.Ks[i]));
b0_tensors_device.emplace_back(std::make_unique<SimpleDeviceMem>(
sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i]));
b1_tensors_device.emplace_back(
std::make_unique<SimpleDeviceMem>(sizeof(B1DataType) * problem_size.Ns[i]));
d0_tensors_device.emplace_back(
std::make_unique<SimpleDeviceMem>(sizeof(D0DataType) * problem_size.Ns[i]));
c_tensors_device.emplace_back(
std::make_unique<SimpleDeviceMem>(sizeof(EDataType) * sum_of_m * problem_size.Ns[i]));
gemm_descs.push_back(
{sum_of_m, problem_size.Ns[i], problem_size.Ks[i], {1}, {1, 1}, {0}, 1});
grouped_gemm_kernel_args_.push_back(
{std::array<const void*, NumATensor>{a0_tensors_device[i]->GetDeviceBuffer()},
std::array<const void*, NumBTensor>{b0_tensors_device[i]->GetDeviceBuffer(),
b1_tensors_device[i]->GetDeviceBuffer()},
std::array<const void*, NumDTensor>{d0_tensors_device[i]->GetDeviceBuffer()},
c_tensors_device[i]->GetDeviceBuffer(),
problem_size.Ms[i],
problem_size.Ns[i],
problem_size.Ks[i],
std::array<ck::index_t, NumATensor>{problem_size.stride_As[i]},
std::array<ck::index_t, NumBTensor>{problem_size.stride_Bs[i], 0},
std::array<ck::index_t, NumDTensor>{0},
problem_size.stride_Cs[i]});
}
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmMultiABDFixedNK<AsLayout,
BsLayout,
DsLayout,
Row,
AsDataType,
BsDataType,
DsDataType,
BF16,
AElementOp,
BElementOp,
CDEElementOp>;
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
std::string best_op_name;
bool found = false;
int best_op_id = -1;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
// profile device operation instances
std::cout << "Run all instances and do timing" << std::endl;
for(int i = 0; i < op_ptrs.size(); ++i)
{
auto& op_ptr = op_ptrs[i];
std::vector<std::array<const void*, NumATensor>> p_As = {};
std::vector<std::array<const void*, NumBTensor>> p_Bs = {};
std::vector<std::array<const void*, NumDTensor>> p_Ds = {};
std::vector<void*> p_Cs = {};
auto argument_ptr = op_ptr->MakeArgumentPointer(p_As, p_Bs, p_Ds, p_Cs, gemm_descs);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
std::string op_name = op_ptr->GetTypeString();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
SimpleDeviceMem gemm_kernel_args_dev(
op_ptr->GetDeviceKernelArgSize(argument_ptr.get()));
hip_check_error(hipMemcpy(gemm_kernel_args_dev.GetDeviceBuffer(),
grouped_gemm_kernel_args_.data(),
op_ptr->GetDeviceKernelArgSize(argument_ptr.get()),
hipMemcpyHostToDevice));
op_ptr->SetDeviceKernelArgs(argument_ptr.get(), gemm_kernel_args_dev.GetDeviceBuffer());
op_ptr->SetElementwiseOps(
argument_ptr.get(), a_element_op, b_element_op, cde_element_op);
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
std::size_t flop = std::size_t(2) * sum_of_m * problem_size.Ns[0] * problem_size.Ks[0];
std::size_t num_btype = sizeof(A0DataType) * sum_of_m * problem_size.Ks[0] +
sizeof(B0DataType) * problem_size.Ks[0] * problem_size.Ns[0] +
sizeof(EDataType) * sum_of_m * problem_size.Ns[0];
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
if(tflops > best_tflops)
{
found = true;
best_op_id = i;
best_op_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
}
else
{
std::cout << op_name << " does not support this problem" << std::endl;
}
}
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
return true;
}
int main(int argc, char* argv[])
{
ProblemSize problem_size;
ExecutionConfig config;
problem_size.group_count = 16;
for(int i = 0; i < problem_size.group_count; i++)
{
problem_size.Ms.push_back(32 + rand() % 32);
problem_size.Ns.push_back(1024);
problem_size.Ks.push_back(512);
problem_size.stride_As.push_back(problem_size.Ks[i]);
problem_size.stride_Bs.push_back(problem_size.Ns[i]);
problem_size.stride_Cs.push_back(problem_size.Ns[i]);
}
return !run_grouped_gemm(problem_size, config);
}

View File

@@ -0,0 +1,282 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <iomanip>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_multi_abd_fixed_nk.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using BF16 = ck::bhalf_t;
using I8 = int8_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = BF16;
using AsDataType = ck::Tuple<A0DataType>;
using B0DataType = I8;
using B1DataType = BF16;
using BsDataType = ck::Tuple<B0DataType, B1DataType>;
using AccDataType = F32;
using CShuffleDataType = BF16;
using D0DataType = BF16;
using DsDataType = ck::Tuple<>;
using EDataType = BF16;
using A0Layout = Col;
using AsLayout = ck::Tuple<A0Layout>;
using B0Layout = Row;
using B1Layout = B0Layout;
using BsLayout = ck::Tuple<B0Layout, B1Layout>;
using D0Layout = Row;
using DsLayout = ck::Tuple<>;
using ELayout = Row;
using Scales = ck::tensor_operation::element_wise::Scales;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using AElementOp = PassThrough;
using BElementOp = Scales;
using CDEElementOp = FastGelu;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
struct SimpleDeviceMem
{
SimpleDeviceMem() = delete;
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
{
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
}
void* GetDeviceBuffer() { return p_mem_; }
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
void* p_mem_;
};
struct ProblemSize final
{
std::vector<ck::index_t> Ms;
std::vector<ck::index_t> Ns;
std::vector<ck::index_t> Ks;
std::vector<ck::index_t> stride_As;
std::vector<ck::index_t> stride_Bs;
std::vector<ck::index_t> stride_Cs;
ck::index_t group_count;
};
struct ExecutionConfig final
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
int k_batch = 1;
};
bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
{
auto group_count = problem_size.group_count;
// GEMM shape
std::vector<ck::tensor_operation::device::GemmMultiABDDesc> gemm_descs;
gemm_descs.reserve(group_count);
int sum_of_m = 0;
using DeviceMemPtr = std::unique_ptr<SimpleDeviceMem>;
std::vector<DeviceMemPtr> a0_tensors_device, b0_tensors_device, b1_tensors_device,
c_tensors_device;
a0_tensors_device.reserve(group_count);
b0_tensors_device.reserve(group_count);
b1_tensors_device.reserve(group_count);
c_tensors_device.reserve(group_count);
std::size_t flop = 0, num_btype = 0;
for(int i = 0; i < group_count; i++)
{
sum_of_m += problem_size.Ms[i];
}
constexpr ck::index_t NumATensor = 1;
constexpr ck::index_t NumBTensor = 2;
constexpr ck::index_t NumDTensor = 0;
using GroupedGemmKernelArgument = ck::tensor_operation::device::
GroupedGemmMultiABDKernelArgument<NumATensor, NumBTensor, NumDTensor>;
std::vector<GroupedGemmKernelArgument> grouped_gemm_kernel_args_;
grouped_gemm_kernel_args_.reserve(group_count);
for(int i = 0; i < group_count; i++)
{
a0_tensors_device.emplace_back(
std::make_unique<SimpleDeviceMem>(sizeof(A0DataType) * sum_of_m * problem_size.Ks[i]));
b0_tensors_device.emplace_back(std::make_unique<SimpleDeviceMem>(
sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i]));
b1_tensors_device.emplace_back(
std::make_unique<SimpleDeviceMem>(sizeof(B1DataType) * problem_size.Ns[i]));
c_tensors_device.emplace_back(
std::make_unique<SimpleDeviceMem>(sizeof(EDataType) * sum_of_m * problem_size.Ns[i]));
gemm_descs.push_back(
{sum_of_m, problem_size.Ns[i], problem_size.Ks[i], {1}, {1, 1}, {}, 1});
grouped_gemm_kernel_args_.push_back(
{std::array<const void*, NumATensor>{a0_tensors_device[i]->GetDeviceBuffer()},
std::array<const void*, NumBTensor>{b0_tensors_device[i]->GetDeviceBuffer(),
b1_tensors_device[i]->GetDeviceBuffer()},
std::array<const void*, NumDTensor>{},
c_tensors_device[i]->GetDeviceBuffer(),
problem_size.Ms[i],
problem_size.Ns[i],
problem_size.Ks[i],
std::array<ck::index_t, NumATensor>{problem_size.stride_As[i]},
std::array<ck::index_t, NumBTensor>{problem_size.stride_Bs[i], 0},
std::array<ck::index_t, NumDTensor>{},
problem_size.stride_Cs[i]});
}
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmMultiABDFixedNK<AsLayout,
BsLayout,
DsLayout,
Row,
AsDataType,
BsDataType,
DsDataType,
BF16,
AElementOp,
BElementOp,
CDEElementOp>;
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
std::string best_op_name;
bool found = false;
int best_op_id = -1;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
// profile device operation instances
std::cout << "Run all instances and do timing" << std::endl;
for(int i = 0; i < op_ptrs.size(); ++i)
{
auto& op_ptr = op_ptrs[i];
std::vector<std::array<const void*, NumATensor>> p_As = {};
std::vector<std::array<const void*, NumBTensor>> p_Bs = {};
std::vector<std::array<const void*, NumDTensor>> p_Ds = {};
std::vector<void*> p_Cs = {};
auto argument_ptr = op_ptr->MakeArgumentPointer(p_As, p_Bs, p_Ds, p_Cs, gemm_descs);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
std::string op_name = op_ptr->GetTypeString();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
SimpleDeviceMem gemm_kernel_args_dev(
op_ptr->GetDeviceKernelArgSize(argument_ptr.get()));
hip_check_error(hipMemcpy(gemm_kernel_args_dev.GetDeviceBuffer(),
grouped_gemm_kernel_args_.data(),
op_ptr->GetDeviceKernelArgSize(argument_ptr.get()),
hipMemcpyHostToDevice));
op_ptr->SetDeviceKernelArgs(argument_ptr.get(), gemm_kernel_args_dev.GetDeviceBuffer());
op_ptr->SetElementwiseOps(
argument_ptr.get(), a_element_op, b_element_op, cde_element_op);
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
std::size_t flop = std::size_t(2) * sum_of_m * problem_size.Ns[0] * problem_size.Ks[0];
std::size_t num_btype = sizeof(A0DataType) * sum_of_m * problem_size.Ks[0] +
sizeof(B0DataType) * problem_size.Ks[0] * problem_size.Ns[0] +
sizeof(EDataType) * sum_of_m * problem_size.Ns[0];
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
if(tflops > best_tflops)
{
found = true;
best_op_id = i;
best_op_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
}
else
{
std::cout << op_name << " does not support this problem" << std::endl;
}
}
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
return true;
}
int main(int argc, char* argv[])
{
ProblemSize problem_size;
ExecutionConfig config;
problem_size.group_count = 16;
for(int i = 0; i < problem_size.group_count; i++)
{
problem_size.Ms.push_back(32 + rand() % 32);
problem_size.Ns.push_back(1024);
problem_size.Ks.push_back(512);
problem_size.stride_As.push_back(problem_size.Ks[i]);
problem_size.stride_Bs.push_back(problem_size.Ns[i]);
problem_size.stride_Cs.push_back(problem_size.Ns[i]);
}
return !run_grouped_gemm(problem_size, config);
}

View File

@@ -0,0 +1,7 @@
add_custom_target(example_grouped_gemm_xdl_multi_abd)
add_example_executable(example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16 grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp)
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16)
add_example_executable(example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8 grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp)
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8)

View File

@@ -0,0 +1,401 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using BF16 = ck::bhalf_t;
using I8 = int8_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Add = ck::tensor_operation::element_wise::Add;
using A0DataType = BF16;
using AsDataType = ck::Tuple<A0DataType>;
using B0DataType = I8;
using B1DataType = BF16;
using BsDataType = ck::Tuple<B0DataType, B1DataType>;
using AccDataType = F32;
using CShuffleDataType = BF16;
using D0DataType = BF16;
using DsDataType = ck::Tuple<D0DataType>;
using EDataType = BF16;
using A0Layout = Row;
using AsLayout = ck::Tuple<A0Layout>;
using B0Layout = Col;
using B1Layout = B0Layout;
using BsLayout = ck::Tuple<B0Layout, B1Layout>;
using DsLayout = ck::Tuple<Row>;
using ELayout = Row;
using Scales = ck::tensor_operation::element_wise::Scales;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using AElementOp = PassThrough;
using BElementOp = Scales;
using CDEElementOp = AddFastGelu;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK
// clang-format off
///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
///######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 128, 16, 128, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>;
// clang-format on
struct ProblemSize final
{
std::vector<ck::index_t> Ms;
std::vector<ck::index_t> Ns;
std::vector<ck::index_t> Ks;
std::vector<ck::index_t> stride_As;
std::vector<ck::index_t> stride_Bs;
std::vector<ck::index_t> stride_Cs;
ck::index_t group_count;
};
struct ExecutionConfig final
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
int k_batch = 1;
};
bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
{
auto group_count = problem_size.group_count;
// GEMM shape
std::vector<ck::tensor_operation::device::GemmMultiABDDesc> gemm_descs;
gemm_descs.reserve(group_count);
int sum_of_m = 0;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
std::vector<Tensor<A0DataType>> a0_tensors;
std::vector<Tensor<B1DataType>> b_tensors;
std::vector<Tensor<B0DataType>> b0_tensors;
std::vector<Tensor<B1DataType>> b1_tensors;
std::vector<Tensor<D0DataType>> d0_tensors;
std::vector<Tensor<EDataType>> c_host_tensors;
std::vector<Tensor<EDataType>> c_device_tensors;
a0_tensors.reserve(group_count);
b_tensors.reserve(group_count);
b0_tensors.reserve(group_count);
b1_tensors.reserve(group_count);
d0_tensors.reserve(group_count);
c_host_tensors.reserve(group_count);
c_device_tensors.reserve(group_count);
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
std::vector<DeviceMemPtr> a0_tensors_device, b0_tensors_device, b1_tensors_device,
d0_tensors_device, c_tensors_device;
a0_tensors_device.reserve(group_count);
b0_tensors_device.reserve(group_count);
b1_tensors_device.reserve(group_count);
d0_tensors_device.reserve(group_count);
c_tensors_device.reserve(group_count);
std::size_t flop = 0, num_btype = 0;
for(int i = 0; i < group_count; i++)
{
sum_of_m += problem_size.Ms[i];
a0_tensors.push_back(Tensor<A0DataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], A0Layout{})));
b_tensors.push_back(Tensor<B1DataType>(f_host_tensor_descriptor(
problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], B0Layout{})));
b0_tensors.push_back(Tensor<B0DataType>(f_host_tensor_descriptor(
problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], B0Layout{})));
b1_tensors.push_back(Tensor<B1DataType>(
f_host_tensor_descriptor(problem_size.Ks[i], problem_size.Ns[i], 0, B1Layout{})));
d0_tensors.push_back(Tensor<D0DataType>(
f_host_tensor_descriptor(problem_size.Ms[i], problem_size.Ns[i], 0, ELayout{})));
c_host_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
c_device_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
std::cout << "gemm[" << i << "] a_m_k: " << a0_tensors[i].mDesc
<< " b_k_n: " << b0_tensors[i].mDesc << " d_m_n: " << d0_tensors[i].mDesc
<< " c_m_n: " << c_device_tensors[i].mDesc << std::endl;
flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i];
num_btype += sizeof(A0DataType) * a0_tensors[i].mDesc.GetElementSize() +
sizeof(B0DataType) * b0_tensors[i].mDesc.GetElementSize() +
sizeof(B1DataType) * b1_tensors[i].mDesc.GetElementSize() +
sizeof(D0DataType) * d0_tensors[i].mDesc.GetElementSize() +
sizeof(EDataType) * c_device_tensors[i].mDesc.GetElementSize();
switch(config.init_method)
{
case 0: break;
case 1:
a0_tensors[i].GenerateTensorValue(GeneratorTensor_2<A0DataType>{-5, 5});
b0_tensors[i].GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
b1_tensors[i].GenerateTensorValue(GeneratorTensor_2<B1DataType>{0, 5});
break;
case 2:
a0_tensors[i].GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_tensors[i].GenerateTensorValue(GeneratorTensor_3<B0DataType>{-5, 5});
b1_tensors[i].GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
break;
default:
a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
b0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
b1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
}
d0_tensors[i].GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
}
constexpr ck::index_t NumATensor = 1;
constexpr ck::index_t NumBTensor = 2;
constexpr ck::index_t NumDTensor = 1;
using GroupedGemmKernelArgument = ck::tensor_operation::device::
GroupedGemmMultiABDKernelArgument<NumATensor, NumBTensor, NumDTensor>;
std::vector<GroupedGemmKernelArgument> grouped_gemm_kernel_args_;
grouped_gemm_kernel_args_.reserve(group_count);
for(int i = 0; i < group_count; i++)
{
a0_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(A0DataType) * sum_of_m * problem_size.Ks[i]));
b0_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i]));
b1_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(B1DataType) * problem_size.Ns[i]));
d0_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(D0DataType) * problem_size.Ns[i]));
c_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(EDataType) * sum_of_m * problem_size.Ns[i]));
a0_tensors_device[i]->ToDevice(a0_tensors[i].mData.data(),
a0_tensors[i].mDesc.GetElementSpaceSize() *
sizeof(A0DataType));
b0_tensors_device[i]->ToDevice(b0_tensors[i].mData.data(),
b0_tensors[i].mDesc.GetElementSpaceSize() *
sizeof(B0DataType));
b1_tensors_device[i]->ToDevice(b1_tensors[i].mData.data(),
b1_tensors[i].mDesc.GetElementSpaceSize() *
sizeof(B1DataType));
d0_tensors_device[i]->ToDevice(d0_tensors[i].mData.data());
c_tensors_device[i]->SetZero();
gemm_descs.push_back(
{sum_of_m, problem_size.Ns[i], problem_size.Ks[i], {1}, {1, 1}, {0}, 1});
grouped_gemm_kernel_args_.push_back(
{std::array<const void*, NumATensor>{a0_tensors_device[i]->GetDeviceBuffer()},
std::array<const void*, NumBTensor>{b0_tensors_device[i]->GetDeviceBuffer(),
b1_tensors_device[i]->GetDeviceBuffer()},
std::array<const void*, NumDTensor>{d0_tensors_device[i]->GetDeviceBuffer()},
c_tensors_device[i]->GetDeviceBuffer(),
problem_size.Ms[i],
problem_size.Ns[i],
problem_size.Ks[i],
std::array<ck::index_t, NumATensor>{problem_size.stride_As[i]},
std::array<ck::index_t, NumBTensor>{problem_size.stride_Bs[i], 0},
std::array<ck::index_t, NumDTensor>{0},
problem_size.stride_Cs[i]});
}
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
std::vector<std::array<const void*, NumATensor>> p_As = {};
std::vector<std::array<const void*, NumBTensor>> p_Bs = {};
std::vector<std::array<const void*, NumDTensor>> p_Ds = {};
std::vector<void*> p_Cs = {};
// do GEMM
auto argument = gemm.MakeArgument(p_As, p_Bs, p_Ds, p_Cs, gemm_descs);
if(!gemm.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument));
gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer());
DeviceMem gemm_kernel_args_dev(gemm.GetDeviceKernelArgSize(&argument));
hip_check_error(hipMemcpy(gemm_kernel_args_dev.GetDeviceBuffer(),
grouped_gemm_kernel_args_.data(),
gemm.GetDeviceKernelArgSize(&argument),
hipMemcpyHostToDevice));
gemm.SetDeviceKernelArgs(argument, gemm_kernel_args_dev.GetDeviceBuffer());
gemm.SetKBatch(argument, config.k_batch);
gemm.SetElementwiseOps(argument, a_element_op, b_element_op, cde_element_op);
invoker.Run(argument, StreamConfig{nullptr, false});
if(config.time_kernel)
{
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << gemm.GetTypeString() << std::endl;
}
bool pass = true;
if(config.do_verification)
{
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<A0DataType,
B1DataType,
EDataType,
AccDataType,
PassThrough,
PassThrough,
PassThrough>;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
for(int n = 0; n < problem_size.Ns[i]; ++n)
{
for(int k = 0; k < problem_size.Ks[i]; ++k)
{
b_element_op(b_tensors[i](k, n), b0_tensors[i](k, n), b1_tensors[i](k, n));
}
}
c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data(),
c_device_tensors[i].mDesc.GetElementSize() *
sizeof(EDataType));
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a0_tensors[i],
b_tensors[i],
c_host_tensors[i],
PassThrough{},
PassThrough{},
PassThrough{});
ref_invoker.Run(ref_argument);
for(int m = 0; m < problem_size.Ms[i]; ++m)
{
for(int n = 0; n < problem_size.Ns[i]; ++n)
{
cde_element_op(
c_host_tensors[i](m, n), c_host_tensors[i](m, n), d0_tensors[i](m, n));
}
}
pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]);
}
}
return pass;
}
int main(int argc, char* argv[])
{
ProblemSize problem_size;
ExecutionConfig config;
problem_size.group_count = 16;
for(int i = 0; i < problem_size.group_count; i++)
{
problem_size.Ms.push_back(32 + rand() % 32);
problem_size.Ns.push_back(1024);
problem_size.Ks.push_back(512);
problem_size.stride_As.push_back(problem_size.Ks[i]);
problem_size.stride_Bs.push_back(problem_size.Ks[i]);
problem_size.stride_Cs.push_back(problem_size.Ns[i]);
}
if(argc == 5)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
config.k_batch = std::stoi(argv[4]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=n0, 1=yes)\n");
printf("arg4: k_batch (>0)\n");
exit(0);
}
return !run_grouped_gemm(problem_size, config);
}

View File

@@ -0,0 +1,397 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp"
#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Add = ck::tensor_operation::element_wise::Add;
using Scale = ck::tensor_operation::element_wise::Scale;
using AddScale = ck::tensor_operation::element_wise::BinaryWithUnaryCombinedOp<Add, Scale, Scale>;
using A0DataType = F16;
using A1DataType = F32;
using AsDataType = ck::Tuple<A0DataType, A1DataType>;
using B0DataType = F16;
using BsDataType = ck::Tuple<B0DataType>;
using AccDataType = F32;
using CShuffleDataType = F32;
using D0DataType = F32;
using DsDataType = ck::Tuple<D0DataType>;
using EDataType = F32;
using A0Layout = Row;
using A1Layout = Row;
using AsLayout = ck::Tuple<A0Layout, A1Layout>;
using B0Layout = Col;
using BsLayout = ck::Tuple<B0Layout>;
using D0Layout = Row;
using DsLayout = ck::Tuple<D0Layout>;
using ELayout = Row;
using AElementOp = AddScale;
using BElementOp = PassThrough;
using CDEElementOp = Add;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK
// clang-format off
///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
///######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 128, 16, 128, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1, ck::half_t>;
// clang-format on
struct ProblemSize final
{
std::vector<ck::index_t> Ms;
std::vector<ck::index_t> Ns;
std::vector<ck::index_t> Ks;
std::vector<ck::index_t> stride_As;
std::vector<ck::index_t> stride_Bs;
std::vector<ck::index_t> stride_Cs;
ck::index_t group_count;
};
struct ExecutionConfig final
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
int k_batch = 1;
};
bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
{
auto group_count = problem_size.group_count;
// GEMM shape
std::vector<ck::tensor_operation::device::GemmMultiABDDesc> gemm_descs;
gemm_descs.reserve(group_count);
int sum_of_m = 0;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
std::vector<Tensor<A0DataType>> a0_tensors;
std::vector<Tensor<A1DataType>> a1_tensors;
std::vector<Tensor<B0DataType>> b_tensors;
std::vector<Tensor<D0DataType>> d0_tensors;
std::vector<Tensor<EDataType>> e_host_tensors;
std::vector<Tensor<EDataType>> e_device_tensors;
a0_tensors.reserve(group_count);
a1_tensors.reserve(group_count);
b_tensors.reserve(group_count);
d0_tensors.reserve(group_count);
e_host_tensors.reserve(group_count);
e_device_tensors.reserve(group_count);
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
std::vector<DeviceMemPtr> a0_tensors_device, a1_tensors_device, b_tensors_device,
d0_tensors_device, c_tensors_device;
a0_tensors_device.reserve(group_count);
a1_tensors_device.reserve(group_count);
b_tensors_device.reserve(group_count);
d0_tensors_device.reserve(group_count);
c_tensors_device.reserve(group_count);
std::size_t flop = 0, num_btype = 0;
for(int i = 0; i < group_count; i++)
{
sum_of_m += problem_size.Ms[i];
a0_tensors.push_back(Tensor<A0DataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], A0Layout{})));
a1_tensors.push_back(Tensor<A1DataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], A1Layout{})));
b_tensors.push_back(Tensor<B0DataType>(f_host_tensor_descriptor(
problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], B0Layout{})));
d0_tensors.push_back(Tensor<D0DataType>(
f_host_tensor_descriptor(problem_size.Ms[i], problem_size.Ns[i], 0, ELayout{})));
e_host_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
e_device_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
std::cout << "gemm[" << i << "] a_m_k: " << a0_tensors[i].mDesc
<< " b_k_n: " << b_tensors[i].mDesc << " d_m_n: " << d0_tensors[i].mDesc
<< " c_m_n: " << e_device_tensors[i].mDesc << std::endl;
flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i];
num_btype += sizeof(A0DataType) * a0_tensors[i].mDesc.GetElementSize() +
sizeof(A1DataType) * a1_tensors[i].mDesc.GetElementSize() +
sizeof(B0DataType) * b_tensors[i].mDesc.GetElementSize() +
sizeof(D0DataType) * d0_tensors[i].mDesc.GetElementSize() +
sizeof(EDataType) * e_device_tensors[i].mDesc.GetElementSize();
switch(config.init_method)
{
case 0: break;
case 1:
a0_tensors[i].GenerateTensorValue(GeneratorTensor_2<A0DataType>{-5, 5});
a1_tensors[i].GenerateTensorValue(GeneratorTensor_2<A1DataType>{-5, 5});
b_tensors[i].GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
break;
case 2:
a0_tensors[i].GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
a1_tensors[i].GenerateTensorValue(GeneratorTensor_3<A1DataType>{0.0, 1.0});
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
break;
default:
a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
a1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
}
d0_tensors[i].GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
}
constexpr ck::index_t NumATensor = 2;
constexpr ck::index_t NumBTensor = 1;
constexpr ck::index_t NumDTensor = 1;
using GroupedGemmKernelArgument = ck::tensor_operation::device::
GroupedGemmMultiABDKernelArgument<NumATensor, NumBTensor, NumDTensor>;
std::vector<GroupedGemmKernelArgument> grouped_gemm_kernel_args_;
grouped_gemm_kernel_args_.reserve(group_count);
for(int i = 0; i < group_count; i++)
{
a0_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(A0DataType) * sum_of_m * problem_size.Ks[i]));
a1_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(A1DataType) * sum_of_m * problem_size.Ks[i]));
b_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(B0DataType) * problem_size.Ns[i] * problem_size.Ks[i]));
d0_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(D0DataType) * problem_size.Ns[i]));
c_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(EDataType) * sum_of_m * problem_size.Ns[i]));
a0_tensors_device[i]->ToDevice(a0_tensors[i].mData.data(),
a0_tensors[i].mDesc.GetElementSpaceSize() *
sizeof(A0DataType));
a1_tensors_device[i]->ToDevice(a1_tensors[i].mData.data(),
a1_tensors[i].mDesc.GetElementSpaceSize() *
sizeof(A1DataType));
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data(),
b_tensors[i].mDesc.GetElementSpaceSize() *
sizeof(B0DataType));
d0_tensors_device[i]->ToDevice(d0_tensors[i].mData.data());
c_tensors_device[i]->SetZero();
gemm_descs.push_back({sum_of_m,
problem_size.Ns[i],
problem_size.Ks[i],
{1, 1},
{problem_size.stride_Bs[i]},
{0},
1});
grouped_gemm_kernel_args_.push_back(
{std::array<const void*, NumATensor>{a0_tensors_device[i]->GetDeviceBuffer(),
a1_tensors_device[i]->GetDeviceBuffer()},
std::array<const void*, NumBTensor>{b_tensors_device[i]->GetDeviceBuffer()},
std::array<const void*, NumDTensor>{d0_tensors_device[i]->GetDeviceBuffer()},
c_tensors_device[i]->GetDeviceBuffer(),
problem_size.Ms[i],
problem_size.Ns[i],
problem_size.Ks[i],
std::array<ck::index_t, NumATensor>{problem_size.stride_As[i],
problem_size.stride_As[i]},
std::array<ck::index_t, NumBTensor>{problem_size.stride_Bs[i]},
std::array<ck::index_t, NumDTensor>{0},
problem_size.stride_Cs[i]});
}
constexpr float scale = 1.f;
auto a_element_op = AElementOp{Add{}, Scale{scale}, Scale{scale}};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
std::vector<std::array<const void*, NumATensor>> p_As = {};
std::vector<std::array<const void*, NumBTensor>> p_Bs = {};
std::vector<std::array<const void*, NumDTensor>> p_Ds = {};
std::vector<void*> p_Cs = {};
// do GEMM
auto argument = gemm.MakeArgument(p_As, p_Bs, p_Ds, p_Cs, gemm_descs);
if(!gemm.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument));
gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer());
DeviceMem gemm_kernel_args_dev(gemm.GetDeviceKernelArgSize(&argument));
hip_check_error(hipMemcpy(gemm_kernel_args_dev.GetDeviceBuffer(),
grouped_gemm_kernel_args_.data(),
gemm.GetDeviceKernelArgSize(&argument),
hipMemcpyHostToDevice));
gemm.SetDeviceKernelArgs(argument, gemm_kernel_args_dev.GetDeviceBuffer());
gemm.SetKBatch(argument, config.k_batch);
gemm.SetElementwiseOps(argument, a_element_op, b_element_op, cde_element_op);
invoker.Run(argument, StreamConfig{nullptr, false});
if(config.time_kernel)
{
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << gemm.GetTypeString() << std::endl;
}
bool pass = true;
if(config.do_verification)
{
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<A0DataType,
B0DataType,
EDataType,
AccDataType,
PassThrough,
BElementOp,
PassThrough>;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
for(int m = 0; m < problem_size.Ms[i]; ++m)
{
for(int k = 0; k < problem_size.Ks[i]; ++k)
{
a_element_op(a0_tensors[i](m, k), a0_tensors[i](m, k), a1_tensors[i](m, k));
}
}
c_tensors_device[i]->FromDevice(e_device_tensors[i].mData.data(),
e_device_tensors[i].mDesc.GetElementSize() *
sizeof(EDataType));
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a0_tensors[i],
b_tensors[i],
e_host_tensors[i],
PassThrough{},
b_element_op,
PassThrough{});
ref_invoker.Run(ref_argument);
for(int m = 0; m < problem_size.Ms[i]; ++m)
{
for(int n = 0; n < problem_size.Ns[i]; ++n)
{
cde_element_op(
e_host_tensors[i](m, n), e_host_tensors[i](m, n), d0_tensors[i](m, n));
}
}
pass &= ck::utils::check_err(e_device_tensors[i], e_host_tensors[i]);
}
}
return pass;
}
int main(int argc, char* argv[])
{
ProblemSize problem_size;
ExecutionConfig config;
problem_size.group_count = 16;
for(int i = 0; i < problem_size.group_count; i++)
{
problem_size.Ms.push_back(32 + rand() % 32);
problem_size.Ns.push_back(1024);
problem_size.Ks.push_back(512);
problem_size.stride_As.push_back(problem_size.Ks[i]);
problem_size.stride_Bs.push_back(problem_size.Ks[i]);
problem_size.stride_Cs.push_back(problem_size.Ns[i]);
}
if(argc == 5)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
config.k_batch = std::stoi(argv[4]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=n0, 1=yes)\n");
printf("arg4: k_batch (>0)\n");
exit(0);
}
return !run_grouped_gemm(problem_size, config);
}

View File

@@ -1 +1,2 @@
add_example_executable(example_gemm_multi_ABD_xdl_fp16 gemm_multi_ABD_xdl_fp16.cpp)
add_example_executable(example_gemm_multi_ABD_xdl_bf16_i8 gemm_multi_ABD_xdl_bf16_i8.cpp)

View File

@@ -0,0 +1,270 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using BF16 = ck::bhalf_t;
using I8 = int8_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = BF16;
using AsDataType = ck::Tuple<A0DataType>;
using B0DataType = I8;
using B1DataType = BF16;
using BsDataType = ck::Tuple<B0DataType, B1DataType>;
using AccDataType = F32;
using CShuffleDataType = BF16;
using D0DataType = BF16;
using DsDataType = ck::Tuple<D0DataType>;
using EDataType = BF16;
using A0Layout = Row;
using AsLayout = ck::Tuple<A0Layout>;
using B0Layout = Col;
using B1Layout = B0Layout;
using BsLayout = ck::Tuple<B0Layout, B1Layout>;
using D0Layout = Row;
using DsLayout = ck::Tuple<D0Layout>;
using ELayout = Row;
using Scales = ck::tensor_operation::element_wise::Scales;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using AElementOp = PassThrough;
using BElementOp = Scales;
using CDEElementOp = AddFastGelu;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Xdl_CShuffle
// clang-format off
///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
///######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 16, 128, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>;
// clang-format on
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 64;
ck::index_t N = 1024;
ck::index_t K = 512;
ck::index_t StrideA = K;
ck::index_t StrideB = K;
ck::index_t StrideD = N;
ck::index_t StrideE = N;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 11)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideD = std::stoi(argv[9]);
StrideE = std::stoi(argv[10]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n");
exit(0);
}
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
Tensor<B1DataType> b1_k_n(f_host_tensor_descriptor(K, N, 0, B1Layout{}));
Tensor<D0DataType> d_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{}));
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl;
std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl;
std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl;
std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-5, 5});
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
b1_k_n.GenerateTensorValue(GeneratorTensor_2<B1DataType>{0, 5});
d_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-5, 5});
break;
default:
a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 5});
d_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
}
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize());
DeviceMem b1_device_buf(sizeof(B1DataType) * b1_k_n.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(D0DataType) * d_m_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
a0_device_buf.ToDevice(a0_m_k.mData.data());
b0_device_buf.ToDevice(b0_k_n.mData.data());
b1_device_buf.ToDevice(b1_k_n.mData.data());
d_device_buf.ToDevice(d_m_n.mData.data());
e_device_buf.ToDevice(e_m_n_device_result.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
constexpr ck::index_t NumATensor = 1;
constexpr ck::index_t NumBTensor = 2;
constexpr ck::index_t NumDTensor = 1;
// do GEMM
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument =
device_op.MakeArgument(std::array<const void*, NumATensor>{a0_device_buf.GetDeviceBuffer()},
std::array<const void*, NumBTensor>{b0_device_buf.GetDeviceBuffer(),
b1_device_buf.GetDeviceBuffer()},
std::array<const void*, NumDTensor>{d_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
M,
N,
K,
std::array<ck::index_t, NumATensor>{StrideA},
std::array<ck::index_t, NumBTensor>{StrideB, 0},
std::array<ck::index_t, NumDTensor>{StrideD},
StrideE,
a_element_op,
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
if(do_verification)
{
Tensor<CShuffleDataType> c_m_n({M, N});
Tensor<A0DataType> a_m_k({M, K});
Tensor<B1DataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
for(int n = 0; n < N; ++n)
{
for(int k = 0; k < K; ++k)
{
b_element_op(b_k_n(k, n), b0_k_n(k, n), b1_k_n(k, n));
}
}
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<A0DataType,
B1DataType,
CShuffleDataType,
AccDataType,
PassThrough,
PassThrough,
PassThrough>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a0_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
ref_invoker.Run(ref_argument);
for(int m = 0; m < M; ++m)
{
for(int n = 0; n < N; ++n)
{
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n));
}
}
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1;
}
return 0;
}

View File

@@ -37,7 +37,7 @@ using DDataType = F16;
using EDataType = F16;
using ALayout = Row;
using BLayout = Col;
using BLayout = Row;
using DLayout = Row;
using ELayout = Row;
@@ -141,9 +141,9 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Xdl
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
1,
2,
8,
8,
1,
1,
1,
@@ -161,10 +161,10 @@ int main(int argc, char* argv[])
ck::index_t N = 4096;
ck::index_t K = 4096;
ck::index_t StrideA = 4096;
ck::index_t StrideB = 4096;
ck::index_t StrideD = 4096;
ck::index_t StrideE = 4096;
ck::index_t StrideA = K;
ck::index_t StrideB = N;
ck::index_t StrideD = N;
ck::index_t StrideE = N;
float alpha = 1.0f;
float beta = 1.0f;

View File

@@ -102,7 +102,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceContractionMultiple
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
1,
8,
1,
S<4, 64, 1>,
@@ -131,7 +131,7 @@ int main(int argc, char* argv[])
std::vector<ck::index_t> a0_ms_ks_strides{128 * 32 * 64, 32 * 64, 64, 1};
// A1[M1, K1] -> A1[M0, M1, K0, K1]
std::vector<ck::index_t> a1_ms_ks_lengths{30, 128, 32, 64};
std::vector<ck::index_t> a1_ms_ks_strides{0, 64, 0, 1};
std::vector<ck::index_t> a1_ms_ks_strides{0, 64, 1, 0};
// B[N0, N1, K0, K1]
std::vector<ck::index_t> b_ns_ks_lengths{32, 64, 32, 64};
std::vector<ck::index_t> b_ns_ks_strides{64 * 32 * 64, 32 * 64, 64, 1};

View File

@@ -0,0 +1,98 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
struct GemmMultiABDDesc
{
ck::index_t M_, N_, K_;
std::vector<ck::index_t> stride_As_;
std::vector<ck::index_t> stride_Bs_;
std::vector<ck::index_t> stride_Ds_;
ck::index_t stride_C_;
};
/*
* \brief Grouped Gemm Multi ABD
*
* C = a_op(A, A1...) * b_op(B, B1...)
* E = cde_op(C, D0, D1, ...)
*
* \tparam AsLayout A layouts (tuple).
* \tparam BsLayout B layouts (tuple).
* \tparam DsLayout Ds layouts (tuple).
* \tparam ELayout Output layout.
* \tparam AsDataType A data types (tuple).
* \tparam BsDataType B data types (tuple).
* \tparam DsDataType D data types (tuple).
* \tparam EDataType Output data type.
* \tparam AElementwiseOperation A elementwise operation.
* \tparam BElementwiseOperation B elementwise operation.
* \tparam CDEElementwiseOperation C elementwise operation.
*/
template <typename AsLayout,
typename BsLayout,
typename DsLayout,
typename ELayout,
typename AsDataType,
typename BsDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceGroupedGemmMultiABD : public BaseOperator
{
static constexpr index_t NumATensor = AsDataType::Size();
static constexpr index_t NumBTensor = BsDataType::Size();
static constexpr index_t NumDTensor = DsDataType::Size();
static_assert(AsLayout::Size() == AsDataType::Size(), "wrong! inconsistent NumATensor");
static_assert(BsLayout::Size() == BsDataType::Size(), "wrong! inconsistent NumBTensor");
static_assert(DsLayout::Size() == DsDataType::Size(), "wrong! inconsistent NumDTensor");
/*
* \brief Make argument pointer for grouped gemm multi abd.
*
* \param p_as A pointers to the A.
* \param p_bs A pointers to the B.
* \param p_ds A pointers to the Ds.
* \param p_e A pointers to the E.
* \param gemm_desc Gemm descriptors for each group.
* \param a_element_op A elementwise operation object.
* \param b_element_op B elementwise operation object.
* \param cde_element_op CDE elementwise operation object.
* \return Pointer to the argument.
*/
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::vector<std::array<const void*, NumATensor>>& p_as,
std::vector<std::array<const void*, NumBTensor>>& p_bs,
std::vector<std::array<const void*, NumDTensor>>& p_ds,
std::vector<void*>& p_e,
std::vector<GemmMultiABDDesc>& gemm_desc,
AElementwiseOperation a_element_op = AElementwiseOperation{},
BElementwiseOperation b_element_op = BElementwiseOperation{},
CDEElementwiseOperation c_element_op = CDEElementwiseOperation{}) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
virtual void SetElementwiseOps(BaseArgument* p_arg,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) const = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,81 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <array>
#include "device_grouped_gemm_multi_abd.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
struct GroupedGemmMultiABDKernelArgument
{
std::array<const void*, NumATensor> p_as_grid;
std::array<const void*, NumBTensor> p_bs_grid;
std::array<const void*, NumDTensor> p_ds_grid;
void* p_e_grid;
index_t M;
index_t N;
index_t K;
std::array<index_t, NumATensor> StrideAs;
std::array<index_t, NumBTensor> StrideBs;
std::array<index_t, NumDTensor> StrideDs;
index_t StrideE;
};
/*
* \brief Grouped Gemm Multi ABD Fixed NK
*
* C = a_op(A, A1...) * b_op(B, B1...)
* E = cde_op(C, D0, D1, ...)
*
* \tparam AsLayout A layouts (tuple).
* \tparam BsLayout B layouts (tuple).
* \tparam DsLayout Ds layouts (tuple).
* \tparam ELayout Output layout.
* \tparam AsDataType A data types (tuple).
* \tparam BsDataType B data types (tuple).
* \tparam DsDataType D data types (tuple).
* \tparam EDataType Output data type.
* \tparam AElementwiseOperation A elementwise operation.
* \tparam BElementwiseOperation B elementwise operation.
* \tparam CDEElementwiseOperation C elementwise operation.
*/
template <typename AsLayout,
typename BsLayout,
typename DsLayout,
typename ELayout,
typename AsDataType,
typename BsDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGroupedGemmMultiABDFixedNK : DeviceGroupedGemmMultiABD<AsLayout,
BsLayout,
DsLayout,
ELayout,
AsDataType,
BsDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
virtual void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const = 0;
virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const = 0;
virtual void SetKBatch(BaseArgument* p_arg, index_t k_batch) const = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -663,7 +663,8 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
const bool valid_a_access_dim_k =
ABlockTransferSrcVectorDim == 2 && arg.as_kz_consecutive_[i];
const bool valid_a_access_dim = valid_a_access_dim_m || valid_a_access_dim_k;
if(!(valid_a_vector_size && valid_a_access_dim))
if(!((valid_a_vector_size && valid_a_access_dim) ||
ABlockTransferSrcScalarPerVector == 1))
{
valid_as_access = false;
}
@@ -682,7 +683,8 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
const bool valid_b_access_dim_k =
BBlockTransferSrcVectorDim == 2 && arg.bs_kz_consecutive_[i];
const bool valid_b_access_dim = valid_b_access_dim_n || valid_b_access_dim_k;
if(!(valid_b_vector_size && valid_b_access_dim))
if(!((valid_b_vector_size && valid_b_access_dim) ||
BBlockTransferSrcScalarPerVector == 1))
{
valid_bs_access = false;
}
@@ -698,7 +700,8 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0;
// Vector read of Ds is always on N dimension.
const bool valid_d_access_dim = arg.ds_nz_consecutive_[i];
if(!(valid_d_vector_size && valid_d_access_dim))
if(!((valid_d_vector_size && valid_d_access_dim) ||
CDEBlockTransferScalarPerVector_NPerBlock == 1))
{
valid_ds_access = false;
}
@@ -712,7 +715,8 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0;
// Vector write of E is always on N dimension.
const bool valid_e_access_dim = arg.e_nz_consecutive_;
if(!(valid_e_vector_size && valid_e_access_dim))
if(!((valid_e_vector_size && valid_e_access_dim) ||
CDEBlockTransferScalarPerVector_NPerBlock == 1))
{
return false;
}

View File

@@ -169,78 +169,6 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
#if 0
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideAs)
{
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, AsLayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(StrideAs, I1));
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, AsLayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(I1, StrideAs));
}
}();
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
}
static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideBs)
{
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BsLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(I1, StrideBs));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BsLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(StrideBs, I1));
}
}();
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
}
template <typename ELay>
static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
{
const auto e_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELay>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(StrideE, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELay>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(I1, StrideE));
}
}();
return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
}
static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
const std::array<index_t, NumDTensor>& NRaws,
const std::array<index_t, NumDTensor>& DsStride)
{
return generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
},
Number<NumDTensor>{});
}
#endif
using ComputeDataType = EDataType;
// GridwiseGemm
@@ -384,7 +312,7 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
// B desc
bs_grid_desc_n_k_(i) =
GridwiseGemm::template MakeBGridDescriptor_N_K<BLayout, GemmSpec>(
KRaw, NRaw, StrideBs[i]);
NRaw, KRaw, StrideBs[i]);
});
// populate pointer, desc for Ds
@@ -424,15 +352,6 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
}
}
void Print() const
{
// std::cout << "A[M, K]: " << as_grid_desc_m_k_ << std::endl;
// std::cout << "B[N, K]: " << bs_grid_desc_n_k_ << std::endl;
// static_for<0, NumDTensor, 1>{}(
//[&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; });
// std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
}
// private:
// pointers
typename GridwiseGemm::AsGridPointer p_as_grid_;
@@ -578,7 +497,10 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
}
else
{
all_valid = false;
if(ABlockTransferSrcScalarPerVector != 1)
{
all_valid = false;
}
}
});
@@ -602,13 +524,15 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
}
else
{
all_valid = false;
if(BBlockTransferSrcScalarPerVector != 1)
{
all_valid = false;
}
}
});
// check vector load of Ds
// only support RowMajor for now
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
@@ -618,21 +542,21 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
}
});
if(!all_valid)
{
return false;
}
// check vector store of E
// only support RowMajor for now
if constexpr(is_same_v<ELayout, Row>)
{
if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
all_valid = false;
}
}
else
{
all_valid = false;
}
if(!all_valid)
{
return false;
}

View File

@@ -0,0 +1,851 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename GridwiseGemm,
typename GemmDesc,
GemmSpecialization GemmSpec,
typename AsLayout,
typename BsLayout,
typename DsLayout,
typename ELayout,
typename Block2ETileMap,
typename GroupedGemmBlock2ETileMap,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_gemm_xdl_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
const index_t group_count,
const index_t grid_size_grp,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t KBatch = 1;
const index_t block_id = get_block_1d_id();
const auto gemm_desc_ptr =
reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
const index_t group_id = block_id / grid_size_grp;
if(group_id >= group_count)
return;
const index_t M = gemm_desc_ptr[group_id].M;
const index_t N = gemm_desc_ptr[group_id].N;
const index_t K = gemm_desc_ptr[group_id].K;
if(M * N * K == 0)
return;
const auto StrideAs = gemm_desc_ptr[group_id].StrideAs;
const auto StrideBs = gemm_desc_ptr[group_id].StrideBs;
const auto StrideDs = gemm_desc_ptr[group_id].StrideDs;
const auto StrideE = gemm_desc_ptr[group_id].StrideE;
const auto e_grid_desc_m_n =
GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
const index_t BlockStart = group_id * grid_size_grp;
const auto local_b2e_tile_map = Block2ETileMap{e_grid_desc_m_n, KBatch};
const auto local_grid_size = local_b2e_tile_map.CalculateGridSize(e_grid_desc_m_n);
constexpr auto NumATensor = GridwiseGemm::AsGridPointer::Size();
constexpr auto NumBTensor = GridwiseGemm::BsGridPointer::Size();
constexpr auto NumDTensor = GridwiseGemm::DsGridPointer::Size();
typename GridwiseGemm::AsGridPointer p_as_grid_;
typename GridwiseGemm::BsGridPointer p_bs_grid_;
typename GridwiseGemm::DsGridPointer p_ds_grid_;
static_for<0, NumATensor, 1>{}([&](auto i) {
using ADataType = remove_cvref_t<decltype(p_as_grid_(i))>;
p_as_grid_(i) = static_cast<ADataType>(gemm_desc_ptr[group_id].p_as_grid[i]);
});
static_for<0, NumBTensor, 1>{}([&](auto i) {
using BDataType = remove_cvref_t<decltype(p_bs_grid_(i))>;
p_bs_grid_(i) = static_cast<BDataType>(gemm_desc_ptr[group_id].p_bs_grid[i]);
});
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<decltype(p_ds_grid_(i))>;
p_ds_grid_(i) = static_cast<DDataType>(gemm_desc_ptr[group_id].p_ds_grid[i]);
});
index_t id_off = 0;
index_t id_local = get_block_1d_id() - BlockStart;
while(id_local < local_grid_size)
{
const auto block_2_etile_map =
GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off);
GridwiseGemm::
template Run<HasMainKBlockLoop, GemmSpec, AsLayout, BsLayout, DsLayout, ELayout>(
p_as_grid_,
p_bs_grid_,
p_ds_grid_,
gemm_desc_ptr[group_id].p_e_grid,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
M,
N,
K,
StrideAs,
StrideBs,
StrideDs,
StrideE,
block_2_etile_map);
id_off += grid_size_grp;
id_local += grid_size_grp;
}
#else
ignore = gemm_descs_const;
ignore = group_count;
ignore = grid_size_grp;
ignore = a_element_op;
ignore = b_element_op;
ignore = cde_element_op;
#endif
}
template <typename AsLayout,
typename BsLayout,
typename DsLayout,
typename ELayout,
typename AsDataType,
typename BsDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
GemmSpecialization GemmSpec,
ck::index_t NumPrefetch,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t AK1,
ck::index_t BK1,
ck::index_t MPerXDL,
ck::index_t NPerXDL,
ck::index_t MXdlPerWave,
ck::index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
ck::index_t BBlockTransferSrcVectorDim,
ck::index_t BBlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
typename ComputeType = EDataType,
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK
: public DeviceGroupedGemmMultiABDFixedNK<AsLayout,
BsLayout,
DsLayout,
ELayout,
AsDataType,
BsDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
using DeviceOp = DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK;
static constexpr index_t NumATensor = AsDataType::Size();
static constexpr index_t NumBTensor = BsDataType::Size();
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr index_t NumGemmKPrefetchStage = 1;
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultipleABD_xdl_cshuffle<
AsDataType,
BsDataType,
ComputeType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
InMemoryDataOperationEnum::Set,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
template <typename UnderlyingBlockToCTileMap>
struct OffsettedBlockToCTileMapMLoops
{
using underlying_type = UnderlyingBlockToCTileMap;
__host__ __device__ OffsettedBlockToCTileMapMLoops(
UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start, index_t id_off = 0)
{
block_to_ctile_map_ = block_to_ctile_map;
block_start_ = block_start;
id_off_ = id_off;
}
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
auto idx_bot = block_to_ctile_map_.CalculateBottomIndex(
make_multi_index(idx_top[Number<0>{}] - block_start_ + id_off_));
return make_tuple(
// idx_bot[Number<0>{}],
idx_bot[Number<1>{}],
idx_bot[Number<2>{}]);
}
template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
const CTileDim& c_tile_dim) const
{
return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
}
template <typename CGridDesc_M_N>
__host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
{
return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
}
template <typename CGridDesc_M_N>
__host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
{
return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n);
}
UnderlyingBlockToCTileMap block_to_ctile_map_;
index_t block_start_;
index_t id_off_;
};
template <index_t MPerBlock_, index_t NPerBlock_>
struct BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops() = default;
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default;
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default;
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&
operator=(const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default;
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&
operator=(BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default;
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(index_t M,
index_t N,
index_t KBatch,
index_t M01 = 8)
: M_(M), N_(N), KBatch_(KBatch), M01_(M01)
{
}
template <typename CGridDesc_M_N>
__host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
const CGridDesc_M_N& c_grid_desc_m_n, index_t KBatch, index_t M01 = 8)
: BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), KBatch, M01)
{
}
__host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
{
const auto M0 = math::integer_divide_ceil(M, MPerBlock);
const auto N0 = math::integer_divide_ceil(N, NPerBlock);
return M0 * N0 * KBatch_;
}
template <typename CGridDesc_M_N>
__host__ __device__ constexpr index_t
CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
{
return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
}
template <typename CGridDesc_M_N>
__host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
{
return true;
}
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
auto block_1d_id = idx_top[I0];
const auto M0 = math::integer_divide_ceil(M_, MPerBlock_);
const auto N0 = math::integer_divide_ceil(N_, NPerBlock_);
block_1d_id = block_1d_id % (M0 * N0 * KBatch_); // hide groups
const index_t idx_ksplit = block_1d_id / (M0 * N0);
block_1d_id = block_1d_id % (M0 * N0);
index_t idx_N0 = block_1d_id % N0;
index_t idx_M0 = block_1d_id / N0;
const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;
index_t idx_M00 = idx_M0 / M01_;
index_t idx_M01 = idx_M0 % M01_;
index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
return make_tuple(idx_ksplit,
idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
idx_N0_M01_local / M01_adapt);
}
template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
const CTileDim& /* c_tile_dim */) const
{
return true; // always valid provided that user gets grid size from CalculateGridSize()
}
private:
index_t M_;
index_t N_;
index_t KBatch_;
index_t M01_;
};
using Block2ETileMap = BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops<MPerBlock, NPerBlock>;
using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMapMLoops<Block2ETileMap>;
struct GemmBiasTransKernelArg
{
// pointers
std::array<const void*, NumATensor> as_ptr_;
std::array<const void*, NumBTensor> bs_ptr_;
std::array<const void*, NumDTensor> ds_ptr_;
void* e_ptr_;
index_t M_, N_, K_;
std::array<index_t, NumATensor> StrideAs_;
std::array<index_t, NumBTensor> StrideBs_;
std::array<index_t, NumDTensor> StrideDs_;
index_t StrideE_;
};
// Argument
struct Argument : public BaseArgument
{
void UpdateKBatch(index_t) {}
Argument(std::vector<std::array<const void*, NumATensor>>&,
std::vector<std::array<const void*, NumBTensor>>&,
std::vector<std::array<const void*, NumDTensor>>&,
std::vector<void*>&,
std::vector<GemmMultiABDDesc>& gemm_descs,
AElementwiseOperation a_element_op = AElementwiseOperation{},
BElementwiseOperation b_element_op = BElementwiseOperation{},
CDEElementwiseOperation c_element_op = CDEElementwiseOperation{})
: a_element_op_{a_element_op}, b_element_op_{b_element_op}, c_element_op_{c_element_op}
{
grid_size_ = 0;
k_batch_ = 1;
grouped_gemm_kernel_args_dev = nullptr;
group_count_ = ck::type_convert<ck::index_t>(gemm_descs.size());
gemm_desc_kernel_arg_.reserve(group_count_);
index_t group_id = 0;
sum_of_m = gemm_descs[0].M_;
const index_t AverM = math::integer_divide_ceil(sum_of_m, group_count_);
const index_t N = gemm_descs[0].N_;
const index_t K = gemm_descs[0].K_;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
if(sum_of_m != gemm_descs[i].M_ || N != gemm_descs[i].N_ || K != gemm_descs[i].K_)
{
throw std::runtime_error("wrong! M/N/K is not identical");
}
a_mtx_mraw_kraw_.emplace_back(sum_of_m, K);
b_mtx_nraw_kraw_.emplace_back(N, K);
// pointer
std::array<const void*, NumATensor> p_as_grid;
std::array<const void*, NumBTensor> p_bs_grid;
std::array<const void*, NumDTensor> p_ds_grid;
static_for<0, NumATensor, 1>{}([&](auto j) { p_as_grid[j] = nullptr; });
static_for<0, NumBTensor, 1>{}([&](auto j) { p_bs_grid[j] = nullptr; });
static_for<0, NumDTensor, 1>{}([&](auto j) { p_ds_grid[j] = nullptr; });
std::array<index_t, NumATensor> StrideAs;
std::array<index_t, NumBTensor> StrideBs;
std::array<index_t, NumDTensor> StrideDs;
const index_t StrideE = gemm_descs[i].stride_C_;
if(gemm_descs[i].stride_As_.size() != NumATensor)
{
throw std::runtime_error(
"wrong! gemm_descs[i].stride_As_.size() does not match NumATensor");
}
static_for<0, NumATensor, 1>{}(
[&](auto j) { StrideAs[j] = gemm_descs[i].stride_As_[j]; });
if(gemm_descs[i].stride_Bs_.size() != NumBTensor)
{
throw std::runtime_error(
"wrong! gemm_descs[i].stride_Bs_.size() does not match NumBTensor");
}
static_for<0, NumBTensor, 1>{}(
[&](auto j) { StrideBs[j] = gemm_descs[i].stride_Bs_[j]; });
if(gemm_descs[i].stride_Ds_.size() != NumDTensor)
{
throw std::runtime_error(
"wrong! gemm_descs[i].stride_Ds_.size() does not match NumDTensor");
}
static_for<0, NumDTensor, 1>{}(
[&](auto j) { StrideDs[j] = gemm_descs[i].stride_Ds_[j]; });
const auto e_grid_desc_m_n =
GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(
AverM, N, StrideE);
// block-to-e-tile map
const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_};
grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n);
if(group_id * grid_size_grp_ != grid_size_)
{
throw std::runtime_error("wrong! grid_size_grp_ is not identical!");
}
grid_size_ += grid_size_grp_;
// check block-to-E-tile
if(!local_b2c_tile_map.CheckValidity(e_grid_desc_m_n))
{
throw std::runtime_error("wrong! block_2_etile_map validation failed");
}
gemm_desc_kernel_arg_.push_back(GemmBiasTransKernelArg{
p_as_grid,
p_bs_grid,
p_ds_grid,
nullptr,
AverM,
N,
K,
StrideAs,
StrideBs,
StrideDs,
StrideE,
});
group_id++;
}
const auto e_grid_desc_sum_m_n =
GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(
sum_of_m, gemm_desc_kernel_arg_[0].N_, gemm_desc_kernel_arg_[0].StrideE_);
const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_sum_m_n, 1};
barrier_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_sum_m_n);
}
// private:
index_t group_count_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation c_element_op_;
std::vector<GemmBiasTransKernelArg> gemm_desc_kernel_arg_;
std::vector<Tuple<index_t, index_t>> a_mtx_mraw_kraw_;
std::vector<Tuple<index_t, index_t>> b_mtx_nraw_kraw_;
const void* grouped_gemm_kernel_args_dev;
index_t grid_size_;
index_t grid_size_grp_;
index_t barrier_size_grp_;
index_t sum_of_m;
index_t k_batch_ = 1;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
bool has_main_k_block_loop = true;
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
{
if(GridwiseGemm::CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i].K_) !=
has_main_k_block_loop)
{
throw std::runtime_error("wrong! not all gemm has_main_k_block_loop");
}
}
if(arg.grouped_gemm_kernel_args_dev == nullptr)
{
throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullpr");
}
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_, auto e_global_memory_operation_) {
const auto kernel = kernel_grouped_gemm_xdl_fixed_nk<
GridwiseGemm,
GroupedGemmMultiABDKernelArgument<NumATensor, NumBTensor, NumDTensor>,
GemmSpec,
AsLayout,
BsLayout,
DsLayout,
ELayout,
Block2ETileMap,
GroupedGemmBlock2ETileMap,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
e_global_memory_operation_,
has_main_k_block_loop_>;
return launch_and_time_kernel(
stream_config,
kernel,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev),
arg.gemm_desc_kernel_arg_.size(),
arg.grid_size_grp_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
};
constexpr auto AtomicAdd = InMemoryDataOperationEnum::AtomicAdd;
constexpr auto Set = InMemoryDataOperationEnum::Set;
if(arg.k_batch_ > 1)
{
if(has_main_k_block_loop)
{
ave_time =
launch_kernel(integral_constant<bool, true>{},
integral_constant<InMemoryDataOperationEnum, AtomicAdd>{});
}
else
{
ave_time =
launch_kernel(integral_constant<bool, false>{},
integral_constant<InMemoryDataOperationEnum, AtomicAdd>{});
}
}
else
{
if(has_main_k_block_loop)
{
ave_time = launch_kernel(integral_constant<bool, true>{},
integral_constant<InMemoryDataOperationEnum, Set>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, false>{},
integral_constant<InMemoryDataOperationEnum, Set>{});
}
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::type_convert<ck::index_t>(arg.gemm_desc_kernel_arg_.size()) != arg.group_count_)
{
return false;
}
bool supported = true;
// If we use padding we do not support vector loads for dimensions not divisible by vector
// load size.
if constexpr(GemmSpec != GemmSpecialization::Default)
{
// [A|B]BlockTransferSrcVectorDim value define dimension in the block {K0,M,K1} layout,
// thus we have to adapt it to the {M,K} or {N,K} layout.
const auto a_raw_vector_dim = ABlockTransferSrcVectorDim != 1 ? 1 : 0;
const auto b_raw_vector_dim = BBlockTransferSrcVectorDim != 1 ? 1 : 0;
for(index_t i = 0; i < arg.group_count_; ++i)
{
const auto a_vector_dim = arg.a_mtx_mraw_kraw_[i].At(Number<a_raw_vector_dim>{});
const auto b_vector_dim = arg.b_mtx_nraw_kraw_[i].At(Number<b_raw_vector_dim>{});
supported = supported & (a_vector_dim % ABlockTransferSrcScalarPerVector == 0);
supported = supported & (b_vector_dim % BBlockTransferSrcScalarPerVector == 0);
}
}
return supported;
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(std::vector<std::array<const void*, NumATensor>>& p_As,
std::vector<std::array<const void*, NumBTensor>>& p_Bs,
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
std::vector<void*>& p_Es,
std::vector<GemmMultiABDDesc> gemm_descs,
AElementwiseOperation a_element_op = AElementwiseOperation{},
BElementwiseOperation b_element_op = BElementwiseOperation{},
CDEElementwiseOperation c_element_op = CDEElementwiseOperation{})
{
return Argument{
p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::vector<std::array<const void*, NumATensor>>& p_As,
std::vector<std::array<const void*, NumBTensor>>& p_Bs,
std::vector<std::array<const void*, NumDTensor>>& p_Ds,
std::vector<void*>& p_Es,
std::vector<GemmMultiABDDesc>& gemm_descs,
AElementwiseOperation a_element_op = AElementwiseOperation{},
BElementwiseOperation b_element_op = BElementwiseOperation{},
CDEElementwiseOperation c_element_op = CDEElementwiseOperation{}) override
{
return std::make_unique<Argument>(
p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceGroupedGemm_Xdl_Fixed_NK"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< getGemmSpecializationString(GemmSpec)
<< ">";
// clang-format on
return str.str();
}
static void SetElementwiseOps(Argument& arg,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation c_element_op)
{
arg.a_element_op_ = a_element_op;
arg.b_element_op_ = b_element_op;
arg.c_element_op_ = c_element_op;
}
static void SetDeviceKernelArgs(Argument& arg, const void* kernel_args)
{
arg.grouped_gemm_kernel_args_dev = kernel_args;
}
// polymorphic
void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const override
{
return SetDeviceKernelArgs(*dynamic_cast<Argument*>(p_arg), kernel_args);
}
void SetElementwiseOps(BaseArgument* p_arg,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation c_element_op) const override
{
SetElementwiseOps(
*dynamic_cast<Argument*>(p_arg), a_element_op, b_element_op, c_element_op);
}
size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
{
auto arg = *dynamic_cast<const Argument*>(p_arg);
return arg.group_count_ *
sizeof(GroupedGemmMultiABDKernelArgument<NumATensor, NumBTensor, NumDTensor>);
}
#if 0
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{
auto arg = *dynamic_cast<const Argument*>(p_arg);
return arg.group_count_ * arg.barrier_size_grp_ * sizeof(uint32_t);
}
void SetWorkSpacePointer(BaseArgument* p_arg,
void* p_workspace,
const StreamConfig& stream_config = StreamConfig{}) const override
{
auto p_arg_ = dynamic_cast<Argument*>(p_arg);
p_arg_->p_workspace_ = p_workspace;
hip_check_error(
hipMemsetAsync(p_workspace, 0, GetWorkSpaceSize(p_arg), stream_config.stream_id_));
}
#endif
static void SetKBatch(Argument& arg, index_t k_batch) { arg.UpdateKBatch(k_batch); }
// polymorphic
void SetKBatch(BaseArgument* p_arg, index_t k_batch) const override
{
return SetKBatch(*dynamic_cast<Argument*>(p_arg), k_batch);
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -11,7 +11,6 @@
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp"
namespace ck {

View File

@@ -4,7 +4,7 @@
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
namespace tensor_operation {
@@ -92,6 +92,15 @@ struct Add
};
};
struct Scales
{
template <typename Y, typename X0, typename X1>
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const
{
y = ck::type_convert<Y>(ck::type_convert<float>(x0) * ck::type_convert<float>(x1));
}
};
struct Max
{
template <typename Y, typename X0, typename X1>
@@ -485,6 +494,19 @@ struct AddFastGelu
e = type_convert<half_t>(x1_f);
}
template <>
__host__ __device__ constexpr void
operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& e, const bhalf_t& c, const bhalf_t& d) const
{
const float x0_f = type_convert<float>(c) + type_convert<float>(d);
float x1_f = 0;
FastGelu{}.template operator()<float, float>(x1_f, x0_f);
e = type_convert<bhalf_t>(x1_f);
}
template <>
__host__ __device__ constexpr void
operator()<bhalf_t, float, bhalf_t>(bhalf_t& e, const float& c, const bhalf_t& d) const

View File

@@ -14,6 +14,8 @@ namespace element_wise {
template <typename... UnaryOpsSet>
struct UnaryCombinedOp
{
__host__ __device__ UnaryCombinedOp() : unary_ops_() {}
__host__ __device__ UnaryCombinedOp(UnaryOpsSet... unary_ops) : unary_ops_(unary_ops...) {}
template <typename Y, typename X>
@@ -32,6 +34,8 @@ struct UnaryCombinedOp
template <typename BinaryOp, typename UnaryOp0, typename UnaryOp1>
struct BinaryWithUnaryCombinedOp
{
__host__ __device__ BinaryWithUnaryCombinedOp() : binary_op_(), unary_op0_(), unary_op1_() {}
__host__ __device__ BinaryWithUnaryCombinedOp(BinaryOp binary_op,
UnaryOp0 unary_op0,
UnaryOp1 unary_op1)
@@ -63,6 +67,11 @@ template <typename BinaryOp0,
typename UnaryOp2>
struct TrinaryWithUnaryCombinedOp
{
__host__ __device__ TrinaryWithUnaryCombinedOp()
: binary_op0_(), binary_op1_(), unary_op0_(), unary_op1_(), unary_op2_()
{
}
__host__ __device__ TrinaryWithUnaryCombinedOp(BinaryOp0 binary_op0,
BinaryOp0 binary_op1,
UnaryOp0 unary_op0,

View File

@@ -288,10 +288,13 @@ struct ConvertF8RNE
struct Scale
{
__host__ __device__ Scale(float scale) : scale_(scale) {}
__host__ __device__ Scale(float scale = 1.f) : scale_(scale) {}
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const;
__host__ __device__ void operator()(Y& y, const X& x) const
{
y = ck::type_convert<Y>(ck::type_convert<float>(x) * scale_);
}
template <>
__host__ __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
@@ -500,6 +503,36 @@ struct FastGelu
y = type_convert<half_t>(y_f);
}
template <>
__device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
{
float y_f;
this->operator()<float, float>(y_f, x);
y = type_convert<bhalf_t>(y_f);
}
template <>
__device__ void operator()<bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x) const
{
float y_f;
this->operator()<float, float>(y_f, type_convert<float>(x));
y = type_convert<bhalf_t>(y_f);
}
template <>
__host__ void operator()<bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x) const
{
float y_f;
this->operator()<float, float>(y_f, type_convert<float>(x));
y = type_convert<bhalf_t>(y_f);
}
};
// https://paperswithcode.com/method/gelu

View File

@@ -439,7 +439,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
template <typename BLayout, GemmSpecialization GemmSpec>
__host__ __device__ static auto
MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
MakeBGridDescriptor_N_K(const index_t NRaw, const index_t KRaw, const index_t StrideB)
{
constexpr auto matrix_padder =
ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>{
@@ -463,15 +463,15 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
template <typename BsLayout, GemmSpecialization GemmSpec>
__host__ __device__ static auto
MakeBsGridDescriptor_N_K(const std::array<index_t, NumBTensor>& KRaws,
const std::array<index_t, NumBTensor>& NRaws,
MakeBsGridDescriptor_N_K(const std::array<index_t, NumBTensor>& NRaws,
const std::array<index_t, NumBTensor>& KRaws,
const std::array<index_t, NumBTensor>& BsStride)
{
return generate_tuple(
[&](auto i) {
using BLayout = remove_cvref_t<tuple_element_t<i.value, BsLayout>>;
return MakeBGridDescriptor_N_K<BLayout, GemmSpec>(KRaws[i], NRaws[i], BsStride[i]);
return MakeBGridDescriptor_N_K<BLayout, GemmSpec>(NRaws[i], KRaws[i], BsStride[i]);
},
Number<NumBTensor>{});
}
@@ -574,7 +574,6 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
{
return;
}
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
@@ -595,8 +594,10 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
generate_tuple([&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); },
Number<NumATensor>{});
#if 0
static_assert(ABlockTransferSrcScalarPerVector == ABlockTransferDstScalarPerVector_AK1,
"Src and Dst ScalarPerVector must be the same");
#endif
auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
ThisThreadBlock,
@@ -626,8 +627,10 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
generate_tuple([&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); },
Number<NumBTensor>{});
#if 0
static_assert(BBlockTransferSrcScalarPerVector == BBlockTransferDstScalarPerVector_BK1,
"Src and Dst ScalarPerVector must be the same");
#endif
auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
ThisThreadBlock,

View File

@@ -10,38 +10,9 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp"
namespace ck {
// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory
// and sometimes useless instructions:
// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument
// instead
// 2. Don't construct a new tensor coordinate everytime when using it, update and reuse the same
// tensor coordinate instead
// 3. Don't use a pointer to VGPR buffer, use vector instead
namespace detail {
// TODO: How to fix this? It uses an struct instead of lambda because lambda
// doesn't have constructor
template <index_t VectorDim, index_t ScalarPerVector>
struct lambda_scalar_per_access
{
__host__ __device__ constexpr auto operator()(index_t i) const
{
return (i == VectorDim) ? ScalarPerVector : 1;
}
};
template <index_t VectorDim>
struct lambda_scalar_step_in_vector
{
__host__ __device__ constexpr auto operator()(index_t i) const
{
return (i == VectorDim) ? 1 : 0;
}
};
} // namespace detail
// Assume:
// 1. src:
// 1. SrcDesc is known at compile-time

View File

@@ -0,0 +1,67 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory
// and sometimes useless instructions:
// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument
// instead
// 2. Don't construct a new tensor coordinate everytime when using it, update and reuse the same
// tensor coordinate instead
// 3. Don't use a pointer to VGPR buffer, use vector instead
namespace detail {
// TODO: How to fix this? It uses an struct instead of lambda because lambda
// doesn't have constructor
template <index_t VectorDim, index_t ScalarPerVector>
struct lambda_scalar_per_access
{
__host__ __device__ constexpr auto operator()(index_t i) const
{
return (i == VectorDim) ? ScalarPerVector : 1;
}
};
template <index_t VectorDim>
struct lambda_scalar_step_in_vector
{
__host__ __device__ constexpr auto operator()(index_t i) const
{
return (i == VectorDim) ? 1 : 0;
}
};
// TODO: How to fix this? It uses an struct instead of lambda because lambda
// doesn't have constructor
template <index_t SrcVectorDim,
index_t SrcScalarPerVector,
index_t DstVectorDim,
index_t DstScalarPerVector>
struct lambda_scalar_per_access_for_src_and_dst
{
__host__ __device__ constexpr auto operator()(index_t i) const
{
if(i == SrcVectorDim && i == DstVectorDim)
{
return math::lcm(SrcScalarPerVector, DstScalarPerVector);
}
else if(i == SrcVectorDim)
{
return SrcScalarPerVector;
}
else if(i == DstVectorDim)
{
return DstScalarPerVector;
}
else
{
return 1;
}
}
};
} // namespace detail
} // namespace ck

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -7,44 +7,13 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor/static_tensor.hpp"
#include "ck/utility/is_detected.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp"
namespace ck {
namespace detail {
// TODO: How to fix this? It uses an struct instead of lambda because lambda
// doesn't have constructor
template <index_t SrcVectorDim,
index_t SrcScalarPerVector,
index_t DstVectorDim,
index_t DstScalarPerVector>
struct lambda_scalar_per_access_for_src_and_dst
{
__host__ __device__ constexpr auto operator()(index_t i) const
{
if(i == SrcVectorDim && i == DstVectorDim)
{
return math::lcm(SrcScalarPerVector, DstScalarPerVector);
}
else if(i == SrcVectorDim)
{
return SrcScalarPerVector;
}
else if(i == DstVectorDim)
{
return DstScalarPerVector;
}
else
{
return 1;
}
}
};
} // namespace detail
// Assume:
// 1. src_desc and dst_desc are not known at compile-time
// 2. SrcBuffer and DstBuffer are DynamicBuffer

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -8,9 +8,11 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_space_filling_curve.hpp"
#include "ck/utility/is_detected.hpp"
#include "ck/tensor/static_tensor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp"
namespace ck {
// Thread-level multi-source, multi-destination tensor slice data movement
// Assume:
// 1. All sources and destinations are DynamicBuffer
@@ -70,16 +72,18 @@ struct ThreadwiseTensorSliceTransfer_v7r2
static constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
using SrcSpaceFillingCurve = SpaceFillingCurve<SliceLengths,
SrcDimAccessOrder,
remove_cv_t<decltype(src_scalar_per_access)>>;
static constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
using SrcSpaceFillingCurve = SpaceFillingCurve<SliceLengths,
SrcDimAccessOrder,
remove_cv_t<decltype(src_scalar_per_access)>,
false>;
using DstSpaceFillingCurve = SpaceFillingCurve<SliceLengths,
DstDimAccessOrder,
remove_cv_t<decltype(dst_scalar_per_access)>>;
remove_cv_t<decltype(dst_scalar_per_access)>,
false>;
__device__ constexpr ThreadwiseTensorSliceTransfer_v7r2(
const SrcDescs& src_descs,
@@ -139,9 +143,9 @@ struct ThreadwiseTensorSliceTransfer_v7r2
__device__ void RunRead(const SrcDescs& src_descs, const SrcBuffers& src_bufs)
{
// loop over space-filling curve
static_for<0, num_access, 1>{}([&](auto iAccess) {
static_for<0, src_num_access, 1>{}([&](auto iAccess) {
auto src_vectors = generate_vectors<SrcDatas, SrcScalarPerVector>();
auto dst_vectors = generate_vectors<DstDatas, DstScalarPerVector>();
auto elm_vectors = generate_vectors<DstDatas, SrcScalarPerVector>();
// copy data from src_bufs into src_vectors
static_for<0, nSrc, 1>{}([&](auto i) {
@@ -199,7 +203,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2
using elem_op_vec_t = typename vector_type<DstData, elem_op_vec_len>::type;
return dst_vectors(iDst).template AsType<elem_op_vec_t>()(i);
return elm_vectors(iDst).template AsType<elem_op_vec_t>()(i);
},
Number<nDst>{});
@@ -214,10 +218,10 @@ struct ThreadwiseTensorSliceTransfer_v7r2
unpack2(element_op_, dst_data_refs, src_data_refs);
});
dst_vectors_tuple_(iAccess) = dst_vectors;
elm_vectors_tuple_(iAccess) = elm_vectors;
// move coordinate
if constexpr(iAccess.value != num_access - 1)
if constexpr(iAccess.value != src_num_access - 1)
{
constexpr auto forward_step = SrcSpaceFillingCurve::GetForwardStep(iAccess);
@@ -241,15 +245,113 @@ struct ThreadwiseTensorSliceTransfer_v7r2
});
}
__device__ void TransposeFromElmToDst()
{
using DstData = remove_cvref_t<decltype(DstDatas{}[I0])>;
using SrcThreadScratch =
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
DstData,
SrcScalarPerVector,
decltype(GetSrcThreadScratchDescriptor()),
true>;
using DstThreadScratch =
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
DstData,
DstScalarPerVector,
decltype(GetDstThreadScratchDescriptor()),
true>;
SrcThreadScratch elm_thread_scratch_;
DstThreadScratch dst_thread_scratch_;
elm_thread_scratch_.data_ =
bit_cast<decltype(elm_thread_scratch_.data_)>(elm_vectors_tuple_);
if constexpr(SrcVectorDim != DstVectorDim &&
((is_same<half_t, remove_cvref_t<DstData>>::value &&
SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) ||
(is_same<int8_t, remove_cvref_t<DstData>>::value &&
SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
{
// each transpose does
// DstScalarPerVector # of src vectors in src_thread_scratch_
// SrcScalarPerVector # of dst vectors in dst_thread_scratch_
constexpr index_t num_src_vector = Number<DstScalarPerVector>{};
constexpr index_t num_dst_vector = Number<SrcScalarPerVector>{};
// Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose
// TODO: make this logic generic for all scenario
constexpr auto src_scalar_step_in_vector = generate_sequence(
detail::lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});
constexpr auto dst_scalar_step_in_vector = generate_sequence(
detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
constexpr auto scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access_for_src_and_dst<SrcVectorDim,
SrcScalarPerVector,
DstVectorDim,
DstScalarPerVector>{},
Number<nDim>{});
constexpr auto access_lengths = SliceLengths{} / scalar_per_access;
static_ford<decltype(access_lengths)>{}([&](auto access_idx) {
constexpr auto data_idx = access_idx * scalar_per_access;
constexpr auto data_idx_seq = generate_sequence_v2(
[&](auto i) { return Number<data_idx[i]>{}; }, Number<nDim>{});
using src_vector_t = vector_type_maker_t<DstData, SrcScalarPerVector>;
using dst_vector_t = vector_type_maker_t<DstData, DstScalarPerVector>;
// get DstScalarPerVector # of read-only references to src vectors from
// src_thread_scratch_
const auto src_vector_refs = generate_tie(
[&](auto i) -> const src_vector_t& {
// i increment corresponds to movement in DstVectorDim
return elm_thread_scratch_.GetVectorTypeReference(
data_idx_seq + i * dst_scalar_step_in_vector);
},
Number<num_src_vector>{});
// get SrcScalarPerVector # of references to dst vectors from dst_thread_scratch_
auto dst_vector_refs = generate_tie(
[&](auto i) -> dst_vector_t& {
// i increment corresponds to movement in SrcVectorDim
return dst_thread_scratch_.GetVectorTypeReference(
data_idx_seq + i * src_scalar_step_in_vector);
},
Number<num_dst_vector>{});
// do data transpose
transpose_vectors<DstData, DstScalarPerVector, SrcScalarPerVector>{}(
src_vector_refs, dst_vector_refs);
});
}
else
{
static_ford<SliceLengths>{}(
[&](auto idx) { dst_thread_scratch_(idx) = elm_thread_scratch_[idx]; });
}
dst_vectors_tuple_ = bit_cast<decltype(dst_vectors_tuple_)>(dst_thread_scratch_.data_);
}
// DstDescs: Tuple<const DstDesc0&, const DstDesc1&, ...>
// DstBuffers: Tuple<const DstBuffer0&, const DstBuffer1&, ...>
template <typename DstBuffers,
enable_if_t<DstDescs::Size() == DstBuffers::Size(), bool> = false>
enable_if_t<DstDescs::Size() == 1 && DstBuffers::Size() == 1, bool> = false>
__device__ void RunWrite(const DstDescs& dst_descs, DstBuffers dst_bufs)
{
TransposeFromElmToDst();
// loop over space-filling curve
static_for<0, num_access, 1>{}([&](auto iAccess) {
auto dst_vectors = dst_vectors_tuple_[iAccess];
static_for<0, dst_num_access, 1>{}([&](auto iAccess) {
auto dst_vectors = dst_vectors_tuple_[Number<iAccess>{}];
// copy data from buf_vectors into dst_bufs
static_for<0, nDst, 1>{}([&](auto i) {
@@ -269,7 +371,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2
});
// move coordinate
if constexpr(iAccess.value != num_access - 1)
if constexpr(iAccess.value != dst_num_access - 1)
{
constexpr auto forward_step = DstSpaceFillingCurve::GetForwardStep(iAccess);
@@ -312,28 +414,126 @@ struct ThreadwiseTensorSliceTransfer_v7r2
__device__ static constexpr auto GetSrcCoordinateResetStep()
{
if constexpr(num_access == 0)
if constexpr(src_num_access == 0)
{
return typename SrcSpaceFillingCurve::Index{};
}
else
{
return SrcSpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
return SrcSpaceFillingCurve::GetStepBetween(Number<src_num_access - 1>{}, Number<0>{});
}
}
__device__ static constexpr auto GetDstCoordinateResetStep()
{
if constexpr(num_access == 0)
if constexpr(dst_num_access == 0)
{
return typename DstSpaceFillingCurve::Index{};
}
else
{
return DstSpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
return DstSpaceFillingCurve::GetStepBetween(Number<dst_num_access - 1>{}, Number<0>{});
}
}
__device__ static constexpr auto GetSrcThreadScratchDescriptor()
{
// constexpr auto src_scalar_per_access = generate_sequence(
// detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
constexpr auto src_access_lengths_and_vector_length = container_push_back(
sequence_to_tuple_of_number(src_access_lengths), Number<SrcScalarPerVector>{});
// 1st stage of transforms
constexpr auto desc0 =
make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length);
// 2nd stage of transforms
constexpr auto transforms = generate_tuple(
[&](auto i) {
if constexpr(i == SrcVectorDim)
{
return make_merge_transform_v3_division_mod(
make_tuple(src_access_lengths_and_vector_length[i],
src_access_lengths_and_vector_length[Number<nDim>{}]));
}
else
{
return make_pass_through_transform(src_access_lengths_and_vector_length[i]);
}
},
Number<nDim>{});
constexpr auto low_dim_idss = generate_tuple(
[&](auto i) {
if constexpr(i == SrcVectorDim)
{
return Sequence<i.value, nDim>{};
}
else
{
return Sequence<i.value>{};
}
},
Number<nDim>{});
constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
__device__ static constexpr auto GetDstThreadScratchDescriptor()
{
// 1st stage of transforms
// constexpr auto dst_scalar_per_access = generate_sequence(
// detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
constexpr auto dst_access_lengths_and_vector_length = container_push_back(
sequence_to_tuple_of_number(dst_access_lengths), Number<DstScalarPerVector>{});
constexpr auto desc0 =
make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length);
// 2nd stage of transforms
constexpr auto transforms = generate_tuple(
[&](auto i) {
if constexpr(i == DstVectorDim)
{
return make_merge_transform_v3_division_mod(
make_tuple(dst_access_lengths_and_vector_length[i],
dst_access_lengths_and_vector_length[Number<nDim>{}]));
}
else
{
return make_pass_through_transform(dst_access_lengths_and_vector_length[i]);
}
},
Number<nDim>{});
constexpr auto low_dim_idss = generate_tuple(
[&](auto i) {
if constexpr(i == DstVectorDim)
{
return Sequence<i.value, nDim>{};
}
else
{
return Sequence<i.value>{};
}
},
Number<nDim>{});
constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
template <index_t ISrc>
__device__ void MoveSrcSliceWindow(const SrcDescs& src_descs,
@@ -372,11 +572,14 @@ struct ThreadwiseTensorSliceTransfer_v7r2
private:
using SrcVectorsType = decltype(generate_vectors<SrcDatas, SrcScalarPerVector>());
using ElmVectorsType = decltype(generate_vectors<DstDatas, SrcScalarPerVector>());
using DstVectorsType = decltype(generate_vectors<DstDatas, DstScalarPerVector>());
static constexpr auto num_access = SrcSpaceFillingCurve::GetNumOfAccess();
static constexpr auto src_num_access = SrcSpaceFillingCurve::GetNumOfAccess();
static constexpr auto dst_num_access = DstSpaceFillingCurve::GetNumOfAccess();
StaticallyIndexedArray<DstVectorsType, num_access> dst_vectors_tuple_;
StaticallyIndexedArray<ElmVectorsType, src_num_access> elm_vectors_tuple_;
StaticallyIndexedArray<DstVectorsType, dst_num_access> dst_vectors_tuple_;
SrcCoords src_coords_;
DstCoords dst_coords_;

View File

@@ -0,0 +1,468 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using Scales = ck::tensor_operation::element_wise::Scales;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
#ifdef CK_ENABLE_INT8
// RRR
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<Row>,
Row,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<BF16>,
BF16,
PassThrough,
Scales,
AddFastGelu>>>& instances);
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<Row>,
Row,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<BF16>,
BF16,
PassThrough,
Scales,
Add>>>& instances);
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<>,
Row,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<>,
BF16,
PassThrough,
Scales,
FastGelu>>>& instances);
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<>,
Row,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<>,
BF16,
PassThrough,
Scales,
PassThrough>>>& instances);
// RCR
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<ck::Tuple<Row>,
ck::Tuple<Col, Col>,
ck::Tuple<Row>,
Row,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<BF16>,
BF16,
PassThrough,
Scales,
AddFastGelu>>>& instances);
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<ck::Tuple<Row>,
ck::Tuple<Col, Col>,
ck::Tuple<Row>,
Row,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<BF16>,
BF16,
PassThrough,
Scales,
Add>>>& instances);
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<ck::Tuple<Row>,
ck::Tuple<Col, Col>,
ck::Tuple<>,
Row,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<>,
BF16,
PassThrough,
Scales,
FastGelu>>>& instances);
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<ck::Tuple<Row>,
ck::Tuple<Col, Col>,
ck::Tuple<>,
Row,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<>,
BF16,
PassThrough,
Scales,
PassThrough>>>& instances);
// CRR
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<ck::Tuple<Col>,
ck::Tuple<Row, Row>,
ck::Tuple<Row>,
Row,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<BF16>,
BF16,
PassThrough,
Scales,
AddFastGelu>>>& instances);
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_bias_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<ck::Tuple<Col>,
ck::Tuple<Row, Row>,
ck::Tuple<Row>,
Row,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<BF16>,
BF16,
PassThrough,
Scales,
Add>>>& instances);
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<ck::Tuple<Col>,
ck::Tuple<Row, Row>,
ck::Tuple<>,
Row,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<>,
BF16,
PassThrough,
Scales,
FastGelu>>>& instances);
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<ck::Tuple<Col>,
ck::Tuple<Row, Row>,
ck::Tuple<>,
Row,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<>,
BF16,
PassThrough,
Scales,
PassThrough>>>& instances);
#endif
// GEMM + Add + Gelu
template <typename AsLayout,
typename BsLayout,
typename DsLayout,
typename ELayout,
typename AsDataType,
typename BsDataType,
typename DsDataType,
typename EDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGemmMultipleABD<AsLayout,
BsLayout,
DsLayout,
ELayout,
AsDataType,
BsDataType,
DsDataType,
EDataType,
PassThrough,
Scales,
AddFastGelu>>
{
using DeviceOp = DeviceGemmMultipleABD<AsLayout,
BsLayout,
DsLayout,
ELayout,
AsDataType,
BsDataType,
DsDataType,
EDataType,
PassThrough,
Scales,
AddFastGelu>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_INT8
if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> &&
is_same_v<BsDataType, ck::Tuple<I8, BF16>> &&
is_same_v<DsDataType, ck::Tuple<BF16>> && is_same_v<EDataType, BF16>)
{
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
is_same_v<DsLayout, ck::Tuple<Row>> && is_same_v<ELayout, Row>)
{
add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_instances(op_ptrs);
}
if constexpr(is_same_v<AsLayout, ck::Tuple<Col>> &&
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
is_same_v<DsLayout, ck::Tuple<Row>> && is_same_v<ELayout, Row>)
{
add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_v1_instances(op_ptrs);
}
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
is_same_v<BsLayout, ck::Tuple<Col, Col>> &&
is_same_v<DsLayout, ck::Tuple<Row>> && is_same_v<ELayout, Row>)
{
add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_v1_instances(op_ptrs);
}
}
#endif
return op_ptrs;
}
};
// GEMM + Add
template <typename AsLayout,
typename BsLayout,
typename DsLayout,
typename ELayout,
typename AsDataType,
typename BsDataType,
typename DsDataType,
typename EDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGemmMultipleABD<AsLayout,
BsLayout,
DsLayout,
ELayout,
AsDataType,
BsDataType,
DsDataType,
EDataType,
PassThrough,
Scales,
Add>>
{
using DeviceOp = DeviceGemmMultipleABD<AsLayout,
BsLayout,
DsLayout,
ELayout,
AsDataType,
BsDataType,
DsDataType,
EDataType,
PassThrough,
Scales,
Add>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_INT8
if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> &&
is_same_v<BsDataType, ck::Tuple<I8, BF16>> &&
is_same_v<DsDataType, ck::Tuple<BF16>> && is_same_v<EDataType, BF16>)
{
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
is_same_v<DsLayout, ck::Tuple<Row>> && is_same_v<ELayout, Row>)
{
add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_v1_instances(op_ptrs);
}
if constexpr(is_same_v<AsLayout, ck::Tuple<Col>> &&
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
is_same_v<DsLayout, ck::Tuple<Row>> && is_same_v<ELayout, Row>)
{
add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_bias_v1_instances(op_ptrs);
}
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
is_same_v<BsLayout, ck::Tuple<Col, Col>> &&
is_same_v<DsLayout, ck::Tuple<Row>> && is_same_v<ELayout, Row>)
{
add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_v1_instances(op_ptrs);
}
}
#endif
return op_ptrs;
}
};
// GEMM + Gelu
template <typename AsLayout,
typename BsLayout,
typename DsLayout,
typename ELayout,
typename AsDataType,
typename BsDataType,
typename DsDataType,
typename EDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGemmMultipleABD<AsLayout,
BsLayout,
DsLayout,
ELayout,
AsDataType,
BsDataType,
DsDataType,
EDataType,
PassThrough,
Scales,
FastGelu>>
{
using DeviceOp = DeviceGemmMultipleABD<AsLayout,
BsLayout,
DsLayout,
ELayout,
AsDataType,
BsDataType,
DsDataType,
EDataType,
PassThrough,
Scales,
FastGelu>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_INT8
if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> &&
is_same_v<BsDataType, ck::Tuple<I8, BF16>> &&
is_same_v<DsDataType, ck::Tuple<>> && is_same_v<EDataType, BF16>)
{
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
is_same_v<DsLayout, ck::Tuple<>> && is_same_v<ELayout, Row>)
{
add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances(op_ptrs);
}
if constexpr(is_same_v<AsLayout, ck::Tuple<Col>> &&
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
is_same_v<DsLayout, ck::Tuple<>> && is_same_v<ELayout, Row>)
{
add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_v1_instances(op_ptrs);
}
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
is_same_v<BsLayout, ck::Tuple<Col, Col>> &&
is_same_v<DsLayout, ck::Tuple<>> && is_same_v<ELayout, Row>)
{
add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_v1_instances(op_ptrs);
}
}
#endif
return op_ptrs;
}
};
// GEMM
template <typename AsLayout,
typename BsLayout,
typename DsLayout,
typename ELayout,
typename AsDataType,
typename BsDataType,
typename DsDataType,
typename EDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGemmMultipleABD<AsLayout,
BsLayout,
DsLayout,
ELayout,
AsDataType,
BsDataType,
DsDataType,
EDataType,
PassThrough,
Scales,
PassThrough>>
{
using DeviceOp = DeviceGemmMultipleABD<AsLayout,
BsLayout,
DsLayout,
ELayout,
AsDataType,
BsDataType,
DsDataType,
EDataType,
PassThrough,
Scales,
PassThrough>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_INT8
if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> &&
is_same_v<BsDataType, ck::Tuple<I8, BF16>> &&
is_same_v<DsDataType, ck::Tuple<>> && is_same_v<EDataType, BF16>)
{
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
is_same_v<DsLayout, ck::Tuple<>> && is_same_v<ELayout, Row>)
{
add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instances(op_ptrs);
}
if constexpr(is_same_v<AsLayout, ck::Tuple<Col>> &&
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
is_same_v<DsLayout, ck::Tuple<>> && is_same_v<ELayout, Row>)
{
add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_v1_instances(op_ptrs);
}
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
is_same_v<BsLayout, ck::Tuple<Col, Col>> &&
is_same_v<DsLayout, ck::Tuple<>> && is_same_v<ELayout, Row>)
{
add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_v1_instances(op_ptrs);
}
}
#endif
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,470 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using Scales = ck::tensor_operation::element_wise::Scales;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
// RRR
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<Row>,
Row,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<BF16>,
BF16,
PassThrough,
Scales,
AddFastGelu>>>& instances);
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<Row>,
Row,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<BF16>,
BF16,
PassThrough,
Scales,
Add>>>& instances);
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<>,
Row,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<>,
BF16,
PassThrough,
Scales,
FastGelu>>>& instances);
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Row>,
ck::Tuple<Row, Row>,
ck::Tuple<>,
Row,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<>,
BF16,
PassThrough,
Scales,
PassThrough>>>& instances);
// RCR
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Row>,
ck::Tuple<Col, Col>,
ck::Tuple<Row>,
Row,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<BF16>,
BF16,
PassThrough,
Scales,
AddFastGelu>>>& instances);
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Row>,
ck::Tuple<Col, Col>,
ck::Tuple<Row>,
Row,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<BF16>,
BF16,
PassThrough,
Scales,
Add>>>& instances);
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Row>,
ck::Tuple<Col, Col>,
ck::Tuple<>,
Row,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<>,
BF16,
PassThrough,
Scales,
FastGelu>>>& instances);
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Row>,
ck::Tuple<Col, Col>,
ck::Tuple<>,
Row,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<>,
BF16,
PassThrough,
Scales,
PassThrough>>>& instances);
// CRR
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Col>,
ck::Tuple<Row, Row>,
ck::Tuple<Row>,
Row,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<BF16>,
BF16,
PassThrough,
Scales,
AddFastGelu>>>& instances);
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Col>,
ck::Tuple<Row, Row>,
ck::Tuple<Row>,
Row,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<BF16>,
BF16,
PassThrough,
Scales,
Add>>>& instances);
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Col>,
ck::Tuple<Row, Row>,
ck::Tuple<>,
Row,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<>,
BF16,
PassThrough,
Scales,
FastGelu>>>& instances);
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<ck::Tuple<Col>,
ck::Tuple<Row, Row>,
ck::Tuple<>,
Row,
ck::Tuple<BF16>,
ck::Tuple<I8, BF16>,
ck::Tuple<>,
BF16,
PassThrough,
Scales,
PassThrough>>>& instances);
// GEMM + Add + Gelu
template <typename AsLayout,
typename BsLayout,
typename DsLayout,
typename ELayout,
typename AsDataType,
typename BsDataType,
typename DsDataType,
typename EDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGroupedGemmMultiABDFixedNK<AsLayout,
BsLayout,
DsLayout,
ELayout,
AsDataType,
BsDataType,
DsDataType,
EDataType,
PassThrough,
Scales,
AddFastGelu>>
{
using DeviceOp = DeviceGroupedGemmMultiABDFixedNK<AsLayout,
BsLayout,
DsLayout,
ELayout,
AsDataType,
BsDataType,
DsDataType,
EDataType,
PassThrough,
Scales,
AddFastGelu>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> &&
is_same_v<BsDataType, ck::Tuple<I8, BF16>> &&
is_same_v<DsDataType, ck::Tuple<BF16>> && is_same_v<EDataType, BF16>)
{
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
is_same_v<DsLayout, ck::Tuple<Row>> && is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_instances(
op_ptrs);
}
if constexpr(is_same_v<AsLayout, ck::Tuple<Col>> &&
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
is_same_v<DsLayout, ck::Tuple<Row>> && is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_instances(
op_ptrs);
}
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
is_same_v<BsLayout, ck::Tuple<Col, Col>> &&
is_same_v<DsLayout, ck::Tuple<Row>> && is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_instances(
op_ptrs);
}
}
return op_ptrs;
}
};
// GEMM + Add
template <typename AsLayout,
typename BsLayout,
typename DsLayout,
typename ELayout,
typename AsDataType,
typename BsDataType,
typename DsDataType,
typename EDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGroupedGemmMultiABDFixedNK<AsLayout,
BsLayout,
DsLayout,
ELayout,
AsDataType,
BsDataType,
DsDataType,
EDataType,
PassThrough,
Scales,
Add>>
{
using DeviceOp = DeviceGroupedGemmMultiABDFixedNK<AsLayout,
BsLayout,
DsLayout,
ELayout,
AsDataType,
BsDataType,
DsDataType,
EDataType,
PassThrough,
Scales,
Add>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> &&
is_same_v<BsDataType, ck::Tuple<I8, BF16>> &&
is_same_v<DsDataType, ck::Tuple<BF16>> && is_same_v<EDataType, BF16>)
{
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
is_same_v<DsLayout, ck::Tuple<Row>> && is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_instances(
op_ptrs);
}
if constexpr(is_same_v<AsLayout, ck::Tuple<Col>> &&
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
is_same_v<DsLayout, ck::Tuple<Row>> && is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_instances(
op_ptrs);
}
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
is_same_v<BsLayout, ck::Tuple<Col, Col>> &&
is_same_v<DsLayout, ck::Tuple<Row>> && is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_instances(
op_ptrs);
}
}
return op_ptrs;
}
};
// GEMM + Gelu
template <typename AsLayout,
typename BsLayout,
typename DsLayout,
typename ELayout,
typename AsDataType,
typename BsDataType,
typename DsDataType,
typename EDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGroupedGemmMultiABDFixedNK<AsLayout,
BsLayout,
DsLayout,
ELayout,
AsDataType,
BsDataType,
DsDataType,
EDataType,
PassThrough,
Scales,
FastGelu>>
{
using DeviceOp = DeviceGroupedGemmMultiABDFixedNK<AsLayout,
BsLayout,
DsLayout,
ELayout,
AsDataType,
BsDataType,
DsDataType,
EDataType,
PassThrough,
Scales,
FastGelu>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> &&
is_same_v<BsDataType, ck::Tuple<I8, BF16>> &&
is_same_v<DsDataType, ck::Tuple<>> && is_same_v<EDataType, BF16>)
{
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
is_same_v<DsLayout, ck::Tuple<>> && is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_instances(
op_ptrs);
}
if constexpr(is_same_v<AsLayout, ck::Tuple<Col>> &&
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
is_same_v<DsLayout, ck::Tuple<>> && is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_instances(
op_ptrs);
}
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
is_same_v<BsLayout, ck::Tuple<Col, Col>> &&
is_same_v<DsLayout, ck::Tuple<>> && is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_instances(
op_ptrs);
}
}
return op_ptrs;
}
};
// GEMM
template <typename AsLayout,
typename BsLayout,
typename DsLayout,
typename ELayout,
typename AsDataType,
typename BsDataType,
typename DsDataType,
typename EDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGroupedGemmMultiABDFixedNK<AsLayout,
BsLayout,
DsLayout,
ELayout,
AsDataType,
BsDataType,
DsDataType,
EDataType,
PassThrough,
Scales,
PassThrough>>
{
using DeviceOp = DeviceGroupedGemmMultiABDFixedNK<AsLayout,
BsLayout,
DsLayout,
ELayout,
AsDataType,
BsDataType,
DsDataType,
EDataType,
PassThrough,
Scales,
PassThrough>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<AsDataType, ck::Tuple<BF16>> &&
is_same_v<BsDataType, ck::Tuple<I8, BF16>> &&
is_same_v<DsDataType, ck::Tuple<>> && is_same_v<EDataType, BF16>)
{
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
is_same_v<DsLayout, ck::Tuple<>> && is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances(
op_ptrs);
}
if constexpr(is_same_v<AsLayout, ck::Tuple<Col>> &&
is_same_v<BsLayout, ck::Tuple<Row, Row>> &&
is_same_v<DsLayout, ck::Tuple<>> && is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances(
op_ptrs);
}
if constexpr(is_same_v<AsLayout, ck::Tuple<Row>> &&
is_same_v<BsLayout, ck::Tuple<Col, Col>> &&
is_same_v<DsLayout, ck::Tuple<>> && is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances(
op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,10 @@
# ONLY XDL_KERNELS
set(GEMM_MULTI_ABD_INSTANCES)
list(APPEND GEMM_MULTI_ABD_INSTANCES
device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp
device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_km_kn_mn_v1_instance.cpp
)
add_instance_library(device_gemm_multi_abd_instance ${GEMM_MULTI_ABD_INSTANCES})

View File

@@ -0,0 +1,101 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using BF16 = ck::bhalf_t;
using I8 = int8_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = BF16;
using AsDataType = ck::Tuple<A0DataType>;
using B0DataType = I8;
using B1DataType = BF16;
using BsDataType = ck::Tuple<B0DataType, B1DataType>;
using AccDataType = F32;
using CShuffleDataType = BF16;
using D0DataType = BF16;
// using DsDataType = ck::Tuple<D0DataType>;
using EDataType = BF16;
using A0Layout = Col;
using AsLayout = ck::Tuple<A0Layout>;
using B0Layout = Row;
using B1Layout = B0Layout;
using BsLayout = ck::Tuple<B0Layout, B1Layout>;
using D0Layout = Row;
// using DsLayout = ck::Tuple<D0Layout>;
using ELayout = Row;
using Scales = ck::tensor_operation::element_wise::Scales;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using Add = ck::tensor_operation::element_wise::Add;
using AElementOp = PassThrough;
using BElementOp = Scales;
// using CDEElementOp = AddFastGelu;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
template <typename DsLayout,
typename DsDataType,
typename CDEElementOp,
ck::tensor_operation::device::GemmSpecialization GemmSpec,
ck::PipelineVersion PipVer,
ck::LoopScheduler LoopSche>
using device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_instances = std::tuple<
// clang-format off
//###############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| K0Per| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//###############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//###############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//###############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//PipelineVersion::v1
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 48, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 192, 64, 32, 8, 8, 32, 32, 3, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 24, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 32, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 64, 32, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 64, 16, 32, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 4, LoopSche, PipVer>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,101 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using BF16 = ck::bhalf_t;
using I8 = int8_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = BF16;
using AsDataType = ck::Tuple<A0DataType>;
using B0DataType = I8;
using B1DataType = BF16;
using BsDataType = ck::Tuple<B0DataType, B1DataType>;
using AccDataType = F32;
using CShuffleDataType = BF16;
using D0DataType = BF16;
// using DsDataType = ck::Tuple<D0DataType>;
using EDataType = BF16;
using A0Layout = Row;
using AsLayout = ck::Tuple<A0Layout>;
using B0Layout = Row;
using B1Layout = B0Layout;
using BsLayout = ck::Tuple<B0Layout, B1Layout>;
using D0Layout = Row;
// using DsLayout = ck::Tuple<D0Layout>;
using ELayout = Row;
using Scales = ck::tensor_operation::element_wise::Scales;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using Add = ck::tensor_operation::element_wise::Add;
using AElementOp = PassThrough;
using BElementOp = Scales;
// using CDEElementOp = AddFastGelu;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
template <typename DsLayout,
typename DsDataType,
typename CDEElementOp,
ck::tensor_operation::device::GemmSpecialization GemmSpec,
ck::PipelineVersion PipVer,
ck::LoopScheduler LoopSche>
using device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_instances = std::tuple<
// clang-format off
//###############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| K0Per| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//###############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//###############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//###############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//PipelineVersion::v1
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 48, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 192, 64, 32, 8, 8, 32, 32, 3, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 24, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 32, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 64, 32, 32, 32, 8, 8, 32, 32, 1, 1, S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 64, 16, 32, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 4, LoopSche, PipVer>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,101 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using BF16 = ck::bhalf_t;
using I8 = int8_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = BF16;
using AsDataType = ck::Tuple<A0DataType>;
using B0DataType = I8;
using B1DataType = BF16;
using BsDataType = ck::Tuple<B0DataType, B1DataType>;
using AccDataType = F32;
using CShuffleDataType = BF16;
using D0DataType = BF16;
// using DsDataType = ck::Tuple<D0DataType>;
using EDataType = BF16;
using A0Layout = Row;
using AsLayout = ck::Tuple<A0Layout>;
using B0Layout = Col;
using B1Layout = B0Layout;
using BsLayout = ck::Tuple<B0Layout, B1Layout>;
using D0Layout = Row;
// using DsLayout = ck::Tuple<D0Layout>;
using ELayout = Row;
using Scales = ck::tensor_operation::element_wise::Scales;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using Add = ck::tensor_operation::element_wise::Add;
using AElementOp = PassThrough;
using BElementOp = Scales;
// using CDEElementOp = AddFastGelu;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
template <typename DsLayout,
typename DsDataType,
typename CDEElementOp,
ck::tensor_operation::device::GemmSpecialization GemmSpec,
ck::PipelineVersion PipVer,
ck::LoopScheduler LoopSche>
using device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_instances = std::tuple<
// clang-format off
//###############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| K0Per| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//###############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//###############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//###############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//PipelineVersion::v1
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 48, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 192, 64, 32, 8, 8, 32, 32, 3, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 24, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 32, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 64, 32, 32, 32, 8, 8, 32, 32, 1, 1, S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, LoopSche, PipVer>,
DeviceGemmMultipleABD_Xdl_CShuffle< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 64, 16, 32, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 4, LoopSche, PipVer>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,115 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_common.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
BsLayout,
ck::Tuple<D0Layout>,
ELayout,
AsDataType,
BsDataType,
ck::Tuple<D0DataType>,
EDataType,
AElementOp,
BElementOp,
AddFastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_instances<ck::Tuple<D0Layout>,
ck::Tuple<D0DataType>,
AddFastGelu,
GemmMNKPadding,
PipelineVersion::v1,
LoopScheduler::Default>{});
}
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_bias_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
BsLayout,
ck::Tuple<D0Layout>,
ELayout,
AsDataType,
BsDataType,
ck::Tuple<D0DataType>,
EDataType,
AElementOp,
BElementOp,
Add>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_instances<ck::Tuple<D0Layout>,
ck::Tuple<D0DataType>,
Add,
GemmMNKPadding,
PipelineVersion::v1,
LoopScheduler::Default>{});
}
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
BsLayout,
ck::Tuple<>,
ELayout,
AsDataType,
BsDataType,
ck::Tuple<>,
EDataType,
AElementOp,
BElementOp,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_instances<ck::Tuple<>,
ck::Tuple<>,
PassThrough,
GemmMNKPadding,
PipelineVersion::v1,
LoopScheduler::Default>{});
}
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
BsLayout,
ck::Tuple<>,
ELayout,
AsDataType,
BsDataType,
ck::Tuple<>,
EDataType,
AElementOp,
BElementOp,
FastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_km_kn_mn_instances<ck::Tuple<>,
ck::Tuple<>,
FastGelu,
GemmMNKPadding,
PipelineVersion::v1,
LoopScheduler::Default>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,115 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
BsLayout,
ck::Tuple<D0Layout>,
ELayout,
AsDataType,
BsDataType,
ck::Tuple<D0DataType>,
EDataType,
AElementOp,
BElementOp,
AddFastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_instances<ck::Tuple<D0Layout>,
ck::Tuple<D0DataType>,
AddFastGelu,
GemmMNKPadding,
PipelineVersion::v1,
LoopScheduler::Default>{});
}
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
BsLayout,
ck::Tuple<D0Layout>,
ELayout,
AsDataType,
BsDataType,
ck::Tuple<D0DataType>,
EDataType,
AElementOp,
BElementOp,
Add>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_instances<ck::Tuple<D0Layout>,
ck::Tuple<D0DataType>,
Add,
GemmMNKPadding,
PipelineVersion::v1,
LoopScheduler::Default>{});
}
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
BsLayout,
ck::Tuple<>,
ELayout,
AsDataType,
BsDataType,
ck::Tuple<>,
EDataType,
AElementOp,
BElementOp,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_instances<ck::Tuple<>,
ck::Tuple<>,
PassThrough,
GemmMNKPadding,
PipelineVersion::v1,
LoopScheduler::Default>{});
}
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
BsLayout,
ck::Tuple<>,
ELayout,
AsDataType,
BsDataType,
ck::Tuple<>,
EDataType,
AElementOp,
BElementOp,
FastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_instances<ck::Tuple<>,
ck::Tuple<>,
FastGelu,
GemmMNKPadding,
PipelineVersion::v1,
LoopScheduler::Default>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,115 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_common.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
BsLayout,
ck::Tuple<D0Layout>,
ELayout,
AsDataType,
BsDataType,
ck::Tuple<D0DataType>,
EDataType,
AElementOp,
BElementOp,
AddFastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_instances<ck::Tuple<D0Layout>,
ck::Tuple<D0DataType>,
AddFastGelu,
GemmMNKPadding,
PipelineVersion::v1,
LoopScheduler::Default>{});
}
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
BsLayout,
ck::Tuple<D0Layout>,
ELayout,
AsDataType,
BsDataType,
ck::Tuple<D0DataType>,
EDataType,
AElementOp,
BElementOp,
Add>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_instances<ck::Tuple<D0Layout>,
ck::Tuple<D0DataType>,
Add,
GemmMNKPadding,
PipelineVersion::v1,
LoopScheduler::Default>{});
}
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
BsLayout,
ck::Tuple<>,
ELayout,
AsDataType,
BsDataType,
ck::Tuple<>,
EDataType,
AElementOp,
BElementOp,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_instances<ck::Tuple<>,
ck::Tuple<>,
PassThrough,
GemmMNKPadding,
PipelineVersion::v1,
LoopScheduler::Default>{});
}
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_v1_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
BsLayout,
ck::Tuple<>,
ELayout,
AsDataType,
BsDataType,
ck::Tuple<>,
EDataType,
AElementOp,
BElementOp,
FastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_instances<ck::Tuple<>,
ck::Tuple<>,
FastGelu,
GemmMNKPadding,
PipelineVersion::v1,
LoopScheduler::Default>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,10 @@
# ONLY XDL_KERNELS
set(GROUPED_GEMM_FIXED_NK_MULTI_ABD_INSTANCES)
list(APPEND GROUPED_GEMM_FIXED_NK_MULTI_ABD_INSTANCES
device_grouped_gemm_xdl_fixed_nk_bias_gelu_bf16_i8_bf16_mk_kn_mn_instance.cpp
device_grouped_gemm_xdl_fixed_nk_bias_gelu_bf16_i8_bf16_mk_nk_mn_instance.cpp
device_grouped_gemm_xdl_fixed_nk_bias_gelu_bf16_i8_bf16_km_kn_mn_instance.cpp
)
add_instance_library(device_grouped_gemm_fixed_nk_multi_abd_instance ${GROUPED_GEMM_FIXED_NK_MULTI_ABD_INSTANCES})

View File

@@ -0,0 +1,89 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using BF16 = ck::bhalf_t;
using I8 = int8_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = BF16;
using AsDataType = ck::Tuple<A0DataType>;
using B0DataType = I8;
using B1DataType = BF16;
using BsDataType = ck::Tuple<B0DataType, B1DataType>;
using AccDataType = F32;
using CShuffleDataType = BF16;
using D0DataType = BF16;
// using DsDataType = ck::Tuple<D0DataType>;
using EDataType = BF16;
using A0Layout = Col;
using AsLayout = ck::Tuple<A0Layout>;
using B0Layout = Row;
using B1Layout = B0Layout;
using BsLayout = ck::Tuple<B0Layout, B1Layout>;
using D0Layout = Row;
// using DsLayout = ck::Tuple<Row>;
using ELayout = Row;
using Scales = ck::tensor_operation::element_wise::Scales;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using Add = ck::tensor_operation::element_wise::Add;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using AElementOp = PassThrough;
using BElementOp = Scales;
// using CDEElementOp = AddFastGelu;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
template <typename DsLayout,
typename DsDataType,
typename CDEElementOp,
ck::tensor_operation::device::GemmSpecialization GemmSpec>
using device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances = std::tuple<
// clang-format off
//######################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM|NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization|Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 2>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,89 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using BF16 = ck::bhalf_t;
using I8 = int8_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = BF16;
using AsDataType = ck::Tuple<A0DataType>;
using B0DataType = I8;
using B1DataType = BF16;
using BsDataType = ck::Tuple<B0DataType, B1DataType>;
using AccDataType = F32;
using CShuffleDataType = BF16;
using D0DataType = BF16;
// using DsDataType = ck::Tuple<D0DataType>;
using EDataType = BF16;
using A0Layout = Row;
using AsLayout = ck::Tuple<A0Layout>;
using B0Layout = Row;
using B1Layout = B0Layout;
using BsLayout = ck::Tuple<B0Layout, B1Layout>;
using D0Layout = Row;
// using DsLayout = ck::Tuple<Row>;
using ELayout = Row;
using Scales = ck::tensor_operation::element_wise::Scales;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using Add = ck::tensor_operation::element_wise::Add;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using AElementOp = PassThrough;
using BElementOp = Scales;
// using CDEElementOp = AddFastGelu;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
template <typename DsLayout,
typename DsDataType,
typename CDEElementOp,
ck::tensor_operation::device::GemmSpecialization GemmSpec>
using device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances = std::tuple<
// clang-format off
//######################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM|NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization|Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,89 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using BF16 = ck::bhalf_t;
using I8 = int8_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = BF16;
using AsDataType = ck::Tuple<A0DataType>;
using B0DataType = I8;
using B1DataType = BF16;
using BsDataType = ck::Tuple<B0DataType, B1DataType>;
using AccDataType = F32;
using CShuffleDataType = BF16;
using D0DataType = BF16;
// using DsDataType = ck::Tuple<D0DataType>;
using EDataType = BF16;
using A0Layout = Row;
using AsLayout = ck::Tuple<A0Layout>;
using B0Layout = Col;
using B1Layout = B0Layout;
using BsLayout = ck::Tuple<B0Layout, B1Layout>;
using D0Layout = Row;
// using DsLayout = ck::Tuple<Row>;
using ELayout = Row;
using Scales = ck::tensor_operation::element_wise::Scales;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using Add = ck::tensor_operation::element_wise::Add;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using AElementOp = PassThrough;
using BElementOp = Scales;
// using CDEElementOp = AddFastGelu;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
template <typename DsLayout,
typename DsDataType,
typename CDEElementOp,
ck::tensor_operation::device::GemmSpecialization GemmSpec>
using device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances = std::tuple<
// clang-format off
//######################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM|NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization|Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,111 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_km_kn_mn_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<AsLayout,
BsLayout,
ck::Tuple<D0Layout>,
ELayout,
AsDataType,
BsDataType,
ck::Tuple<D0DataType>,
EDataType,
AElementOp,
BElementOp,
AddFastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances<
ck::Tuple<D0Layout>,
ck::Tuple<D0DataType>,
AddFastGelu,
GemmMNKPadding>{});
}
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_bias_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<AsLayout,
BsLayout,
ck::Tuple<D0Layout>,
ELayout,
AsDataType,
BsDataType,
ck::Tuple<D0DataType>,
EDataType,
AElementOp,
BElementOp,
Add>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances<
ck::Tuple<D0Layout>,
ck::Tuple<D0DataType>,
Add,
GemmMNKPadding>{});
}
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<AsLayout,
BsLayout,
ck::Tuple<>,
ELayout,
AsDataType,
BsDataType,
ck::Tuple<>,
EDataType,
AElementOp,
BElementOp,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances<
ck::Tuple<>,
ck::Tuple<>,
PassThrough,
GemmMNKPadding>{});
}
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<AsLayout,
BsLayout,
ck::Tuple<>,
ELayout,
AsDataType,
BsDataType,
ck::Tuple<>,
EDataType,
AElementOp,
BElementOp,
FastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_km_kn_mn_instances<
ck::Tuple<>,
ck::Tuple<>,
FastGelu,
GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,111 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<AsLayout,
BsLayout,
ck::Tuple<D0Layout>,
ELayout,
AsDataType,
BsDataType,
ck::Tuple<D0DataType>,
EDataType,
AElementOp,
BElementOp,
AddFastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances<
ck::Tuple<D0Layout>,
ck::Tuple<D0DataType>,
AddFastGelu,
GemmMNKPadding>{});
}
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<AsLayout,
BsLayout,
ck::Tuple<D0Layout>,
ELayout,
AsDataType,
BsDataType,
ck::Tuple<D0DataType>,
EDataType,
AElementOp,
BElementOp,
Add>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances<
ck::Tuple<D0Layout>,
ck::Tuple<D0DataType>,
Add,
GemmMNKPadding>{});
}
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<AsLayout,
BsLayout,
ck::Tuple<>,
ELayout,
AsDataType,
BsDataType,
ck::Tuple<>,
EDataType,
AElementOp,
BElementOp,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances<
ck::Tuple<>,
ck::Tuple<>,
PassThrough,
GemmMNKPadding>{});
}
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<AsLayout,
BsLayout,
ck::Tuple<>,
ELayout,
AsDataType,
BsDataType,
ck::Tuple<>,
EDataType,
AElementOp,
BElementOp,
FastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_kn_mn_instances<
ck::Tuple<>,
ck::Tuple<>,
FastGelu,
GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,111 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<AsLayout,
BsLayout,
ck::Tuple<D0Layout>,
ELayout,
AsDataType,
BsDataType,
ck::Tuple<D0DataType>,
EDataType,
AElementOp,
BElementOp,
AddFastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances<
ck::Tuple<D0Layout>,
ck::Tuple<D0DataType>,
AddFastGelu,
GemmMNKPadding>{});
}
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<AsLayout,
BsLayout,
ck::Tuple<D0Layout>,
ELayout,
AsDataType,
BsDataType,
ck::Tuple<D0DataType>,
EDataType,
AElementOp,
BElementOp,
Add>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances<
ck::Tuple<D0Layout>,
ck::Tuple<D0DataType>,
Add,
GemmMNKPadding>{});
}
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<AsLayout,
BsLayout,
ck::Tuple<>,
ELayout,
AsDataType,
BsDataType,
ck::Tuple<>,
EDataType,
AElementOp,
BElementOp,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances<
ck::Tuple<>,
ck::Tuple<>,
PassThrough,
GemmMNKPadding>{});
}
void add_device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmMultiABDFixedNK<AsLayout,
BsLayout,
ck::Tuple<>,
ELayout,
AsDataType,
BsDataType,
ck::Tuple<>,
EDataType,
AElementOp,
BElementOp,
FastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_gemm_xdl_fixed_nk_multi_abd_bf16_i8_bf16_mk_nk_mn_instances<
ck::Tuple<>,
ck::Tuple<>,
FastGelu,
GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck