mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
Merge branch 'develop' into ck_tile/gemm_blockscale_abquant
This commit is contained in:
3
.github/scripts/therock_configure_ci.py
vendored
3
.github/scripts/therock_configure_ci.py
vendored
@@ -1,3 +1,6 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import fnmatch
|
||||
import json
|
||||
import os
|
||||
|
||||
4
.github/workflows/therock-ci-linux.yml
vendored
4
.github/workflows/therock-ci-linux.yml
vendored
@@ -20,7 +20,7 @@ jobs:
|
||||
permissions:
|
||||
id-token: write
|
||||
container:
|
||||
image: ghcr.io/rocm/therock_build_manylinux_x86_64@sha256:2f3ebd0beb04c449fdb36933e54bdc69483b914fb9005594d3fc9444c206b54b
|
||||
image: ghcr.io/rocm/therock_build_manylinux_x86_64@sha256:583d473f263a289222c48d4b493e2956b2354a45796f09dee6f2c8ecd4504ab6
|
||||
options: -v /runner/config:/home/awsconfig/
|
||||
env:
|
||||
AMDGPU_FAMILIES: ${{ inputs.amdgpu_families }}
|
||||
@@ -54,7 +54,7 @@ jobs:
|
||||
with:
|
||||
repository: "ROCm/TheRock"
|
||||
path: "TheRock"
|
||||
ref: f3f77a3161922df3eee006b888b439d75b2b4668 # 2025-10-29 commit
|
||||
ref: d76278526218def9fb1b016bc9e421738cb4f8f6 # 2025-12-09 commit
|
||||
|
||||
- name: Setup ccache
|
||||
run: |
|
||||
|
||||
2
.github/workflows/therock-ci.yml
vendored
2
.github/workflows/therock-ci.yml
vendored
@@ -65,7 +65,7 @@ jobs:
|
||||
-DTHEROCK_USE_EXTERNAL_ROCM_LIBRARIES=ON
|
||||
-DTHEROCK_ROCM_LIBRARIES_SOURCE_DIR=../
|
||||
amdgpu_families: "gfx94X-dcgpu"
|
||||
test_runs_on: "linux-mi325-1gpu-ossci-rocm"
|
||||
test_runs_on: "linux-mi325-1gpu-ossci-rocm-frac"
|
||||
|
||||
therock_ci_summary:
|
||||
name: TheRock CI Summary
|
||||
|
||||
2
.github/workflows/therock-test-component.yml
vendored
2
.github/workflows/therock-test-component.yml
vendored
@@ -51,7 +51,7 @@ jobs:
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
with:
|
||||
repository: "ROCm/TheRock"
|
||||
ref: f3f77a3161922df3eee006b888b439d75b2b4668 # 2025-10-29 commit
|
||||
ref: d76278526218def9fb1b016bc9e421738cb4f8f6 # 2025-12-09 commit
|
||||
|
||||
- name: Run setup test environment workflow
|
||||
uses: './.github/actions/setup_test_environment'
|
||||
|
||||
2
.github/workflows/therock-test-packages.yml
vendored
2
.github/workflows/therock-test-packages.yml
vendored
@@ -27,7 +27,7 @@ jobs:
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
with:
|
||||
repository: "ROCm/TheRock"
|
||||
ref: f3f77a3161922df3eee006b888b439d75b2b4668 # 2025-10-29 commit
|
||||
ref: d76278526218def9fb1b016bc9e421738cb4f8f6 # 2025-12-09 commit
|
||||
|
||||
- name: "Configuring CI options"
|
||||
env:
|
||||
|
||||
@@ -20,12 +20,12 @@ repos:
|
||||
)$
|
||||
- repo: local
|
||||
hooks:
|
||||
# - id: copyright-year-checker
|
||||
# name: copyright-year-checker
|
||||
# entry: script/check_copyright_year.sh
|
||||
# verbose: false
|
||||
# language: script
|
||||
# types: [c++]
|
||||
- id: copyright-header-checker
|
||||
name: Check copyright headers
|
||||
entry: script/check_copyright_year.sh
|
||||
verbose: false
|
||||
language: script
|
||||
types_or: [c++, python, shell, cmake]
|
||||
- id: remove-exec-bit
|
||||
name: Remove executable bit from non-executable files
|
||||
entry: script/remove_exec_bit.sh
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Script to convert all mermaid diagrams in CK Tile docs to SVGs.
|
||||
This script:
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""Convert raw HTML mermaid blocks to commented format for SVG conversion."""
|
||||
|
||||
import os
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Helper script to update SVG diagrams from commented mermaid sources in RST files.
|
||||
|
||||
|
||||
@@ -77,3 +77,5 @@ example_compile_options(example_moe_gemm1_xdl_fp8_blockscale PRIVATE ${BLOCKSCAL
|
||||
add_example_executable(example_gemm_add_add_wmma_fp16 gemm_add_add_wmma_fp16.cpp)
|
||||
add_example_executable(example_gemm_multiply_multiply_wmma_fp16_bpreshuffle gemm_multiply_multiply_wmma_fp16_bpreshuffle.cpp)
|
||||
add_example_executable(example_gemm_multiply_multiply_wmma_fp8_bpreshuffle gemm_multiply_multiply_wmma_fp8_bpreshuffle.cpp)
|
||||
add_example_executable(example_gemm_multiply_multiply_wmma_fp8_ab_scale gemm_multiply_multiply_wmma_fp8_ab_scale.cpp)
|
||||
add_example_executable(example_gemm_multiply_multiply_wmma_fp8_blockscale_bpreshuffle gemm_multiply_multiply_wmma_fp8_blockscale_bpreshuffle.cpp)
|
||||
|
||||
@@ -0,0 +1,345 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3_ab_scale.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
|
||||
#include "ck/utility/blkgemmpipe_scheduler.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using BF16 = ck::bhalf_t;
|
||||
using FP8 = ck::f8_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using A0DataType = FP8;
|
||||
using A1DataType = F32;
|
||||
using B0DataType = FP8;
|
||||
using B1DataType = F32;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using DsDataType = ck::Tuple<>;
|
||||
using EDataType = BF16;
|
||||
|
||||
using A0Layout = Row;
|
||||
using B0Layout = Col;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = Row;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
static constexpr ck::index_t Scale_Block_M = 1;
|
||||
static constexpr ck::index_t Scale_Block_N = 128;
|
||||
static constexpr ck::index_t Scale_Block_K = 128;
|
||||
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3
|
||||
// clang-format off
|
||||
<Row, Col, DsLayout, ELayout,
|
||||
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType,
|
||||
AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
256, Scale_Block_M, Scale_Block_N, Scale_Block_K,
|
||||
128, 128, 128,
|
||||
16, 16,
|
||||
16, 16,
|
||||
4, 2,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
1, 1, S<1, 32, 1, 8>, S<8>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
|
||||
// clang-format on
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
bool flush_cache = true;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 128;
|
||||
ck::index_t N = 1024;
|
||||
ck::index_t K = 1024;
|
||||
|
||||
ck::index_t StrideA = K;
|
||||
ck::index_t StrideB = K;
|
||||
ck::index_t StrideE = N;
|
||||
|
||||
ck::index_t KBatch = 1;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 8 || argc == 9)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
|
||||
M = std::stoi(argv[4]);
|
||||
N = std::stoi(argv[5]);
|
||||
K = std::stoi(argv[6]);
|
||||
|
||||
flush_cache = std::stoi(argv[7]);
|
||||
|
||||
if(argc == 9)
|
||||
{
|
||||
KBatch = std::stoi(argv[8]);
|
||||
}
|
||||
|
||||
StrideA = K;
|
||||
StrideB = K;
|
||||
StrideE = N;
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4 to 6: M, N, K\n");
|
||||
printf("arg7: flush both I$ and L2$ (0=no, 1=yes)\n");
|
||||
printf("arg8: KBatch (default: 1)\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
ck::index_t Scale_Stride_AM = (K + Scale_Block_K - 1) / Scale_Block_K;
|
||||
ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K;
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
using namespace ck::literals;
|
||||
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return ck::HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck::HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
ck::Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
|
||||
ck::Tensor<A1DataType> a1_m_k(f_host_tensor_descriptor((M + Scale_Block_M - 1) / Scale_Block_M,
|
||||
(K + Scale_Block_K - 1) / Scale_Block_K,
|
||||
Scale_Stride_AM,
|
||||
A0Layout{}));
|
||||
ck::Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
|
||||
ck::Tensor<B1DataType> b1_k_n(f_host_tensor_descriptor((K + Scale_Block_K - 1) / Scale_Block_K,
|
||||
(N + Scale_Block_N - 1) / Scale_Block_N,
|
||||
Scale_Stride_BN,
|
||||
B0Layout{}));
|
||||
ck::Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
ck::Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
|
||||
std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl;
|
||||
std::cout << "a1_m_k: " << a1_m_k.mDesc << std::endl;
|
||||
std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl;
|
||||
std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl;
|
||||
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_m_k.GenerateTensorValue(GeneratorTensor_2<A1DataType>{-1, 1});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-1, 1});
|
||||
break;
|
||||
case 2:
|
||||
a0_m_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
b0_k_n.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
a1_m_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
|
||||
break;
|
||||
case 3:
|
||||
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_m_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
|
||||
break;
|
||||
case 4:
|
||||
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_m_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
|
||||
break;
|
||||
case 5:
|
||||
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_m_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
|
||||
break;
|
||||
default:
|
||||
a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-0.5, 0.5});
|
||||
b0_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
a1_m_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
|
||||
}
|
||||
|
||||
ck::DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
|
||||
ck::DeviceMem a1_device_buf(sizeof(A1DataType) * a1_m_k.mDesc.GetElementSpaceSize());
|
||||
ck::DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize());
|
||||
ck::DeviceMem b1_device_buf(sizeof(B1DataType) * b1_k_n.mDesc.GetElementSpaceSize());
|
||||
ck::DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
a0_device_buf.ToDevice(a0_m_k.mData.data());
|
||||
a1_device_buf.ToDevice(a1_m_k.mData.data());
|
||||
b0_device_buf.ToDevice(b0_k_n.mData.data());
|
||||
b1_device_buf.ToDevice(b1_k_n.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
constexpr ck::index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
// do GEMM
|
||||
auto device_op = DeviceOpInstance{};
|
||||
std::string op_name = device_op.GetTypeString();
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
auto argument =
|
||||
device_op.MakeArgument(static_cast<A0DataType*>(a0_device_buf.GetDeviceBuffer()),
|
||||
static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()),
|
||||
std::array<const void*, NumDTensor>{},
|
||||
static_cast<EDataType*>(e_device_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
std::array<ck::index_t, NumDTensor>{},
|
||||
StrideE,
|
||||
static_cast<const A1DataType*>(a1_device_buf.GetDeviceBuffer()),
|
||||
static_cast<const B1DataType*>(b1_device_buf.GetDeviceBuffer()),
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
KBatch);
|
||||
|
||||
if(!device_op.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype =
|
||||
sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N;
|
||||
|
||||
float ave_time = .0;
|
||||
|
||||
ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 0, 50, 100});
|
||||
|
||||
int pass = 0;
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
ck::Tensor<AccDataType> c_m_n({M, N});
|
||||
ck::Tensor<float> a_m_k({M, K});
|
||||
ck::Tensor<float> b_k_n({K, N});
|
||||
|
||||
for(int m = 0; m < M; m++)
|
||||
{
|
||||
for(int k = 0; k < K; k++)
|
||||
{
|
||||
a_m_k(m, k) = ck::type_convert<float>(a0_m_k(m, k)) *
|
||||
a1_m_k(m / Scale_Block_M, k / Scale_Block_K);
|
||||
}
|
||||
}
|
||||
|
||||
for(int n = 0; n < N; n++)
|
||||
{
|
||||
for(int k = 0; k < K; k++)
|
||||
{
|
||||
b_k_n(k, n) = ck::type_convert<float>(b0_k_n(k, n)) *
|
||||
b1_k_n(k / Scale_Block_K, n / Scale_Block_N);
|
||||
}
|
||||
}
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<float,
|
||||
float,
|
||||
CShuffleDataType,
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>;
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument =
|
||||
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
e_m_n_host_result(m, n) = ck::type_convert<EDataType>(c_m_n(m, n));
|
||||
}
|
||||
}
|
||||
|
||||
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
|
||||
|
||||
pass = ck::utils::check_err(
|
||||
e_m_n_device_result, e_m_n_host_result, "Error: Incorrect results!", 5e-2, 5e-2)
|
||||
? 0
|
||||
: 1;
|
||||
}
|
||||
|
||||
if(flush_cache)
|
||||
{
|
||||
int rotating_buf = (512 * 1024 * 1024 + num_btype - 1) / num_btype;
|
||||
|
||||
ave_time = invoker.Run(argument,
|
||||
StreamConfig{nullptr, time_kernel, 0, 50, 100, true, rotating_buf});
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 50, 100});
|
||||
}
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< op_name << ", KBatch " << KBatch << std::endl;
|
||||
|
||||
return pass;
|
||||
}
|
||||
@@ -0,0 +1,357 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3_blockscale_bpreshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
|
||||
#include "ck/utility/blkgemmpipe_scheduler.hpp"
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using BF16 = ck::bhalf_t;
|
||||
using FP8 = ck::f8_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using A0DataType = FP8;
|
||||
using A1DataType = F32;
|
||||
using B0DataType = FP8;
|
||||
using B1DataType = F32;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using DsDataType = ck::Tuple<>;
|
||||
using EDataType = BF16;
|
||||
|
||||
using A0Layout = Row;
|
||||
using A1Layout = Col;
|
||||
using B0Layout = Col;
|
||||
using D0Layout = Row;
|
||||
using D1Layout = Col;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = Row;
|
||||
|
||||
static constexpr int KPack = 16;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
static constexpr ck::index_t Scale_Block_M = 1;
|
||||
static constexpr ck::index_t Scale_Block_N = 128;
|
||||
static constexpr ck::index_t Scale_Block_K = 128;
|
||||
|
||||
using DeviceOpInstance =
|
||||
ck::tensor_operation::device::DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle
|
||||
// clang-format off
|
||||
<Row, Col, DsLayout, ELayout,
|
||||
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
256, Scale_Block_M, Scale_Block_N, Scale_Block_K,
|
||||
128, 128, 128,
|
||||
16, 16,
|
||||
16, 16,
|
||||
4, 2,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 16, 16, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 16, 16, 0,
|
||||
1, 1,
|
||||
S<1, 32, 1, 8>, S<8>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, FP8>;
|
||||
// clang-format on
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
bool flush_cache = true;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 128;
|
||||
ck::index_t N = 1024;
|
||||
ck::index_t K = 1024;
|
||||
|
||||
ck::index_t StrideA = K;
|
||||
ck::index_t StrideB = K;
|
||||
ck::index_t StrideE = N;
|
||||
|
||||
ck::index_t KBatch = 1;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 8 || argc == 9)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
|
||||
M = std::stoi(argv[4]);
|
||||
N = std::stoi(argv[5]);
|
||||
K = std::stoi(argv[6]);
|
||||
|
||||
flush_cache = std::stoi(argv[7]);
|
||||
|
||||
if(argc == 9)
|
||||
{
|
||||
KBatch = std::stoi(argv[8]);
|
||||
}
|
||||
|
||||
StrideA = K;
|
||||
StrideB = K;
|
||||
StrideE = N;
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4 to 6: M, N, K\n");
|
||||
printf("arg7: flush both I$ and L2$ (0=no, 1=yes)\n");
|
||||
printf("arg8: KBatch (default: 1)\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
// Transpose the AScale tensor for better performance
|
||||
ck::index_t Scale_Stride_AK = (M + Scale_Block_M - 1) / Scale_Block_M;
|
||||
ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K;
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
using namespace ck::literals;
|
||||
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return ck::HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck::HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
ck::Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
|
||||
ck::Tensor<A1DataType> a1_m_k(f_host_tensor_descriptor((M + Scale_Block_M - 1) / Scale_Block_M,
|
||||
(K + Scale_Block_K - 1) / Scale_Block_K,
|
||||
Scale_Stride_AK,
|
||||
A1Layout{}));
|
||||
ck::Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
|
||||
ck::Tensor<B0DataType> b0_preshuffled(
|
||||
f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); // use laout only for size
|
||||
ck::Tensor<B1DataType> b1_k_n(f_host_tensor_descriptor((K + Scale_Block_K - 1) / Scale_Block_K,
|
||||
(N + Scale_Block_N - 1) / Scale_Block_N,
|
||||
Scale_Stride_BN,
|
||||
B0Layout{}));
|
||||
ck::Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
ck::Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
|
||||
std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl;
|
||||
std::cout << "a1_m_k: " << a1_m_k.mDesc << std::endl;
|
||||
std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl;
|
||||
std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl;
|
||||
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_m_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
|
||||
break;
|
||||
case 2:
|
||||
a0_m_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
b0_k_n.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
a1_m_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
|
||||
break;
|
||||
case 3:
|
||||
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_m_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
|
||||
break;
|
||||
case 4:
|
||||
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_m_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
|
||||
break;
|
||||
case 5:
|
||||
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_m_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
|
||||
break;
|
||||
default:
|
||||
a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-0.5, 0.5});
|
||||
b0_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
a1_m_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
|
||||
}
|
||||
|
||||
ck::DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
|
||||
ck::DeviceMem a1_device_buf(sizeof(A1DataType) * a1_m_k.mDesc.GetElementSpaceSize());
|
||||
ck::DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize());
|
||||
ck::DeviceMem b1_device_buf(sizeof(B1DataType) * b1_k_n.mDesc.GetElementSpaceSize());
|
||||
ck::DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
a0_device_buf.ToDevice(a0_m_k.mData.data());
|
||||
a1_device_buf.ToDevice(a1_m_k.mData.data());
|
||||
b1_device_buf.ToDevice(b1_k_n.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
constexpr ck::index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
// do GEMM
|
||||
auto device_op = DeviceOpInstance{};
|
||||
std::string op_name = device_op.GetTypeString();
|
||||
int NPerWmma = device_op.GetPreShuffleParameters();
|
||||
|
||||
preShuffleBuffer<KPack>(b0_k_n.mData.data(), b0_preshuffled.mData.data(), N, K, NPerWmma);
|
||||
|
||||
b0_device_buf.ToDevice(b0_preshuffled.mData.data());
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
auto argument = device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(),
|
||||
b0_device_buf.GetDeviceBuffer(),
|
||||
std::array<const void*, NumDTensor>{},
|
||||
e_device_buf.GetDeviceBuffer(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
std::array<ck::index_t, NumDTensor>{},
|
||||
StrideE,
|
||||
a1_device_buf.GetDeviceBuffer(),
|
||||
b1_device_buf.GetDeviceBuffer(),
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
KBatch);
|
||||
|
||||
if(!device_op.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype =
|
||||
sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N;
|
||||
|
||||
float ave_time = 0.0f;
|
||||
|
||||
if(flush_cache)
|
||||
{
|
||||
int rotating_buf = (512 * 1024 * 1024 + num_btype - 1) / num_btype;
|
||||
|
||||
ave_time = invoker.Run(argument,
|
||||
StreamConfig{nullptr, time_kernel, 0, 50, 100, true, rotating_buf});
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 50, 100});
|
||||
}
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< op_name << ", KBatch " << KBatch << std::endl;
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
ck::Tensor<AccDataType> c_m_n({M, N});
|
||||
ck::Tensor<float> a_m_k({M, K});
|
||||
ck::Tensor<float> b_k_n({K, N});
|
||||
|
||||
for(int m = 0; m < M; m++)
|
||||
{
|
||||
for(int k = 0; k < K; k++)
|
||||
{
|
||||
a_m_k(m, k) = ck::type_convert<float>(a0_m_k(m, k)) *
|
||||
a1_m_k(m / Scale_Block_M, k / Scale_Block_K);
|
||||
}
|
||||
}
|
||||
|
||||
for(int n = 0; n < N; n++)
|
||||
{
|
||||
for(int k = 0; k < K; k++)
|
||||
{
|
||||
b_k_n(k, n) = ck::type_convert<float>(b0_k_n(k, n)) *
|
||||
b1_k_n(k / Scale_Block_K, n / Scale_Block_N);
|
||||
}
|
||||
}
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<float,
|
||||
float,
|
||||
CShuffleDataType,
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>;
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument =
|
||||
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
e_m_n_host_result(m, n) = ck::type_convert<EDataType>(c_m_n(m, n));
|
||||
}
|
||||
}
|
||||
|
||||
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
|
||||
|
||||
return ck::utils::check_err(
|
||||
e_m_n_device_result, e_m_n_host_result, "Error: Incorrect results!", 5e-2, 5e-2)
|
||||
? 0
|
||||
: 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -770,7 +770,7 @@ def create_kernel(
|
||||
|
||||
class CompatibilityRuleFactory:
|
||||
@staticmethod
|
||||
def get_rules() -> list[CompatibilityRule]:
|
||||
def get_rules() -> List[CompatibilityRule]:
|
||||
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
|
||||
def check_mode(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool:
|
||||
if problem_ctx.mode == "group":
|
||||
@@ -812,7 +812,7 @@ class CompatibilityRuleFactoryGfx9(CompatibilityRuleFactory):
|
||||
_AVAILABLE_PIPELINES = frozenset({"qr", "qr_async", "qs"})
|
||||
|
||||
@classmethod
|
||||
def get_rules(cls) -> list[CompatibilityRule]:
|
||||
def get_rules(cls) -> List[CompatibilityRule]:
|
||||
rules = CompatibilityRuleFactory.get_rules()
|
||||
|
||||
def check_hdim_tile(
|
||||
@@ -846,7 +846,7 @@ class CompatibilityRuleFactoryGfx950(CompatibilityRuleFactoryGfx9):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_rules(cls) -> list[CompatibilityRule]:
|
||||
def get_rules(cls) -> List[CompatibilityRule]:
|
||||
rules = CompatibilityRuleFactoryGfx9.get_rules()
|
||||
|
||||
def check_tile_pipeline(
|
||||
|
||||
@@ -17,6 +17,7 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
gemm_aquant_quantgrouped_preshufflequant.cpp
|
||||
gemm_bquant_quantgrouped_bf8i4.cpp
|
||||
gemm_bquant_quantgrouped_fp8i4.cpp
|
||||
gemm_bquant_quantgrouped_bf16mxfp4.cpp
|
||||
gemm_bquant_quantgrouped_bf8.cpp
|
||||
gemm_bquant_quantgrouped_fp8.cpp
|
||||
gemm_bquant_quantgrouped_preshuffleb.cpp
|
||||
|
||||
@@ -23,7 +23,7 @@ This folder contains examples of quant GEMMs using the ck_tile tile-programming
|
||||
- **Preshuffled GEMM**: Shuffle the GEMM of B (weight) matrix in the warp layout and bypass the shared memory to do the GEMM calculation. Best performance solution for GEMM.
|
||||
- **TransposeC**: Transpose the C Matrix Output layout to have the best coalesced scale reading
|
||||
- **Preshuffled Quant**: Preshuffle the input matrix to load multiple Quant warp blocks along the selected dimension.
|
||||
- **Precision**: Supports fp16, bf16, fp8, bf8, int4 (for B Matrix).
|
||||
- **Precision**: Supports fp16, bf16, fp8, bf8, int4 (for B Matrix), uint8 (split into two fp4 in the pipeline (for B Matrix)).
|
||||
- **Validation**: CPU/GPU validation and error tolerance options.
|
||||
|
||||
## build
|
||||
@@ -53,7 +53,7 @@ args:
|
||||
-stride_b Tensor B stride (default:0)
|
||||
-stride_c Tensor C stride (default:0)
|
||||
-v 0: No validation, 1: Validation on CPU, 2: Validation on GPU (default:1)
|
||||
-prec Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, or bf8i4 (default for both AQuant and Bquant: fp8)
|
||||
-prec Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, bf8i4, or bf16fp4 (default for both AQuant and Bquant: fp8)
|
||||
-warmup Number of iterations before benchmarking the kernel (default:50)
|
||||
-repeat Number of iterations to benchmark the kernel (default:1000)
|
||||
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) , Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigQuantPrefill<T>;
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::pk_fp4_raw_t>, \
|
||||
TypeConfig, \
|
||||
QuantGroupSize, \
|
||||
ck_tile::QuantType::BQuantGrouped>(arg_parser);
|
||||
|
||||
void bquant_quantgrouped_bf16fp4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf16_t,
|
||||
ck_tile::pk_fp4_raw_t,
|
||||
ck_tile::bf16_t,
|
||||
ck_tile::pk_fp4_raw_t>{});
|
||||
|
||||
lut[hash_multiple_strings(
|
||||
{"bf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x32"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 32>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x64"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
lut[hash_multiple_strings(
|
||||
{"bf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x128"})] =
|
||||
[](const ck_tile::ArgParser& arg_parser) {
|
||||
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
return RUN_GEMM_EXAMPLE_PREC_TYPE;
|
||||
};
|
||||
}
|
||||
@@ -32,7 +32,7 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("prec",
|
||||
"fp8",
|
||||
"Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, "
|
||||
"or bf8i4; for ABQuant: fp8, bf8")
|
||||
"bf8i4 or bf16fp4; for ABQuant: fp8, bf8")
|
||||
.insert("warmup", "50", "Number of iterations before benchmarking the kernel")
|
||||
.insert("repeat", "1000", "Number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
@@ -109,6 +109,8 @@ void bquant_quantgrouped_fp8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_bf8i4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_bf16fp4_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshuffleb_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut);
|
||||
void bquant_quantgrouped_preshufflequant_instance_factory(
|
||||
@@ -141,6 +143,7 @@ int main(int argc, char* argv[])
|
||||
bquant_quantgrouped_bf8_instance_factory(lut);
|
||||
bquant_quantgrouped_fp8i4_instance_factory(lut);
|
||||
bquant_quantgrouped_bf8i4_instance_factory(lut);
|
||||
bquant_quantgrouped_bf16fp4_instance_factory(lut);
|
||||
bquant_quantgrouped_preshuffleb_instance_factory(lut);
|
||||
bquant_quantgrouped_preshufflequant_instance_factory(lut);
|
||||
bquant_quantgrouped_preshuffleb_preshufflequant_instance_factory(lut);
|
||||
|
||||
@@ -69,8 +69,10 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
||||
using ComputeType = std::conditional_t<
|
||||
std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>,
|
||||
ADataType,
|
||||
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
|
||||
@@ -145,21 +145,25 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
has_hot_loop_v,
|
||||
tail_number_v>>>>;
|
||||
|
||||
using GemmPipeline = std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
QuantMode == ck_tile::QuantType::TensorQuant,
|
||||
ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
std::conditional_t<GemmConfig::PreshuffleQuant == true,
|
||||
ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
ck_tile::AQuantGemmPipelineAgBgCrMem<PipelineProblem>>,
|
||||
std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::BQuantGrouped,
|
||||
std::conditional_t<GemmConfig::PreshuffleB == true,
|
||||
ck_tile::WPQuantBPipelineAgBgCrV2<PipelineProblem>,
|
||||
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>,
|
||||
ck_tile::ABQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>>;
|
||||
using GemmPipeline = std::conditional_t < QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
QuantMode == ck_tile::QuantType::TensorQuant,
|
||||
ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::AQuantGrouped,
|
||||
std::conditional_t<GemmConfig::PreshuffleQuant == true,
|
||||
ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
ck_tile::AQuantGemmPipelineAgBgCrMem<PipelineProblem>>,
|
||||
std::conditional_t<
|
||||
GemmConfig::PreshuffleB == true,
|
||||
ck_tile::WPQuantBPipelineAgBgCrV2<PipelineProblem>,
|
||||
std::conditional_t<
|
||||
std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_fp4_raw_t>,
|
||||
ck_tile::MxFp4GemmPipelineAgBgCrCompV3<PipelineProblem>,
|
||||
std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::ABQuantGrouped,
|
||||
ck_tile::ABQuantGemmPipelineAgBgCrCompV3<
|
||||
PipelineProblem,
|
||||
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>>>>;
|
||||
|
||||
constexpr bool TiledPermuteN =
|
||||
(BQuantGroupSize::kN > 1) ? false : GemmConfig::TiledMMAPermuteN;
|
||||
@@ -168,28 +172,31 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
printf(
|
||||
"TiledPermuteN: %d (QuantGroupSize::kN=%d)\n", TiledPermuteN, BQuantGroupSize::kN);
|
||||
}
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<typename TypeConfig::ADataType,
|
||||
typename TypeConfig::BDataType,
|
||||
ck_tile::tuple<>,
|
||||
typename TypeConfig::AccDataType,
|
||||
typename TypeConfig::CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
transpose_c,
|
||||
ck_tile::memory_operation_enum::set,
|
||||
1,
|
||||
false,
|
||||
1,
|
||||
TiledPermuteN>>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
typename TypeConfig::ADataType,
|
||||
std::conditional_t<
|
||||
std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_fp4_raw_t>,
|
||||
typename TypeConfig::ADataType,
|
||||
typename TypeConfig::BDataType>,
|
||||
ck_tile::tuple<>,
|
||||
typename TypeConfig::AccDataType,
|
||||
typename TypeConfig::CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
transpose_c,
|
||||
ck_tile::memory_operation_enum::set,
|
||||
1,
|
||||
false,
|
||||
1,
|
||||
TiledPermuteN>>;
|
||||
using Kernel =
|
||||
ck_tile::QuantGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue, QuantMode>;
|
||||
|
||||
@@ -226,7 +233,11 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
ck_tile::HostTensor<typename TypeConfig::ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<typename TypeConfig::BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_fp4_raw_t> ? args.K / 2
|
||||
: args.K,
|
||||
args.N,
|
||||
args.stride_B,
|
||||
is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
|
||||
@@ -484,7 +495,11 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
int rotating_count = arg_parser.get_int("rotating_count");
|
||||
|
||||
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
|
||||
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
|
||||
stride_B = ck_tile::get_default_stride(
|
||||
(std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>) ? (K / 2) : K,
|
||||
N,
|
||||
stride_B,
|
||||
is_row_major(b_layout));
|
||||
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{}));
|
||||
|
||||
// Conditional stride calculation based on QuantMode
|
||||
@@ -516,8 +531,11 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m_k(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(ck_tile::host_tensor_descriptor(
|
||||
(std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>) ? (K / 2) : K,
|
||||
N,
|
||||
stride_B,
|
||||
is_row_major(b_layout)));
|
||||
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
|
||||
@@ -563,13 +581,22 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
|
||||
b_k_n);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
else if constexpr(std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{125.f, 130.f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
}
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
|
||||
*bq_tensor_ptr);
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f, fill_seed(gen)}(a_m_k);
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
|
||||
@@ -817,13 +844,23 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
ck_tile::reference_gemm_quant<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
BQuantGroupSize,
|
||||
false>(a_m_k, *bq_tensor_ptr, b_k_n, c_m_n_host_ref);
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>)
|
||||
ck_tile::reference_mxfp4gemm_quant<ADataType,
|
||||
BQDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
BQuantGroupSize,
|
||||
false>(
|
||||
a_m_k, *bq_tensor_ptr, b_k_n, c_m_n_host_ref);
|
||||
else
|
||||
ck_tile::reference_gemm_quant<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
BQuantGroupSize,
|
||||
false>(a_m_k, *bq_tensor_ptr, b_k_n, c_m_n_host_ref);
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
|
||||
{
|
||||
@@ -896,16 +933,18 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser)
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
if((QuantMode == ck_tile::QuantType::AQuantGrouped ||
|
||||
QuantMode == ck_tile::QuantType::RowColQuant) &&
|
||||
QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_fp4_raw_t>) &&
|
||||
GemmConfig::PreshuffleB)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Preshuffling weight matrix is not supported for AQuant or RowColQuant");
|
||||
"Preshuffling weight matrix is not supported for AQuant, RowColQuant or bf16_fp4_gemm");
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<typename TypeConfig::ADataType, ck_tile::pk_int4_t> ||
|
||||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::fp8_t> ||
|
||||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::bf8_t>)
|
||||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::bf8_t> ||
|
||||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::bf16_t>)
|
||||
{
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// Standalone test program for Old CK GPU references
|
||||
// Tests naive_conv_fwd (existing) and future backward ops
|
||||
|
||||
@@ -4,14 +4,16 @@ This directory contains the builder framework for Composable Kernel, which provi
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Convolution Signature Design](#convolution-signature-design)
|
||||
- [Convolution Signature](#convolution-signature)
|
||||
- [Overview](#overview)
|
||||
- [Architecture](#architecture)
|
||||
- [Core Components](#core-components)
|
||||
- [Concepts and Validation](#concepts-and-validation)
|
||||
- [Convolution Algorithm](#convolution-algorithm)
|
||||
- [Convolution Factory](#convolution-factory)
|
||||
---
|
||||
|
||||
## Convolution Signature Design
|
||||
## Convolution Signature
|
||||
|
||||
### Overview
|
||||
|
||||
@@ -220,25 +222,9 @@ Several fields in the signature are optional:
|
||||
|
||||
This design follows the principle of "make the common case simple, the complex case possible."
|
||||
|
||||
#### Union-Based Layout Representation
|
||||
## Convolution Algorithm
|
||||
|
||||
The `ConvLayout` type uses unions to support dimension-agnostic code:
|
||||
## Convolution Factory
|
||||
|
||||
```cpp
|
||||
struct ConvLayout {
|
||||
union {
|
||||
ConvInputLayout _input_layout;
|
||||
ConvWeightLayout _weight_layout;
|
||||
ConvOutputLayout _output_layout;
|
||||
ConvAuxiliaryTensorLayout _aux_tensor_layout;
|
||||
};
|
||||
// ... constructors for each type
|
||||
};
|
||||
```
|
||||
|
||||
This allows:
|
||||
- Single type to represent all layout variants
|
||||
- Type-safe construction through overloaded constructors
|
||||
- Compile-time enforcement of valid combinations through concepts
|
||||
|
||||
---
|
||||
Convolution factory builds the instance based on the convolution signature and convolution algorithm.
|
||||
The signature and the algorithm descriptions are dispatched to the relevant algorithm specific factory for instance creation. The convolution factory design is described in a separate [Readme](factory/README.md).
|
||||
|
||||
@@ -65,17 +65,19 @@ consteval auto GetTensorDataAndComputeTypes()
|
||||
constexpr auto data_type = Config.data_type;
|
||||
constexpr auto compute_type = Config.compute_type;
|
||||
|
||||
if constexpr(data_type == DataType::UNDEFINDED && compute_type == DataType::UNDEFINDED)
|
||||
using enum DataType;
|
||||
|
||||
if constexpr(data_type == UNDEFINED_DATA_TYPE && compute_type == UNDEFINED_DATA_TYPE)
|
||||
{
|
||||
return std::make_pair(ConvertDataTypeToCK<SignatureDataType>(),
|
||||
ConvertDataTypeToCK<SignatureDataType>());
|
||||
}
|
||||
else if constexpr(data_type == DataType::UNDEFINDED)
|
||||
else if constexpr(data_type == UNDEFINED_DATA_TYPE)
|
||||
{
|
||||
return std::make_pair(ConvertDataTypeToCK<SignatureDataType>(),
|
||||
ConvertDataTypeToCK<compute_type>());
|
||||
}
|
||||
else if constexpr(compute_type == DataType::UNDEFINDED)
|
||||
else if constexpr(compute_type == UNDEFINED_DATA_TYPE)
|
||||
{
|
||||
return std::make_pair(ConvertDataTypeToCK<data_type>(),
|
||||
ConvertDataTypeToCK<SignatureDataType>());
|
||||
@@ -91,7 +93,7 @@ template <DataType SignatureAccDataType, DataType SignatureDataType>
|
||||
consteval auto GetTensorAccumulationType()
|
||||
{
|
||||
constexpr auto data_type = SignatureAccDataType;
|
||||
if constexpr(data_type == DataType::UNDEFINDED)
|
||||
if constexpr(data_type == DataType::UNDEFINED_DATA_TYPE)
|
||||
{
|
||||
return ConvertDataTypeToCK<SignatureDataType>();
|
||||
}
|
||||
@@ -105,7 +107,7 @@ template <auto Config, DataType SignatureDataType>
|
||||
consteval auto GetAuxiliaryTensorDataTypeValue()
|
||||
{
|
||||
constexpr auto data_type = Config.data_type;
|
||||
if constexpr(data_type == DataType::UNDEFINDED)
|
||||
if constexpr(data_type == DataType::UNDEFINED_DATA_TYPE)
|
||||
{
|
||||
return ConvertDataTypeToCK<SignatureDataType>();
|
||||
}
|
||||
|
||||
@@ -316,7 +316,7 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvFwdMultiple
|
||||
oss << "," << kABlockTransferSrcScalarPerVector; // 32. ABlockTransferSrcScalarPerVector
|
||||
oss << ","
|
||||
<< kABlockTransferDstScalarPerVectorK1; // 33. ABlockTransferDstScalarPerVector_AK1
|
||||
oss << "," << kABlockLdsExtraM; // 34. ABlockLdsExtraM
|
||||
oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 34. ABlockLdsExtraM
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kBThreadClusterLengths); // 35. BBlockTransferThreadClusterLengths
|
||||
@@ -329,10 +329,10 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvFwdMultiple
|
||||
oss << "," << kBBlockTransferSrcVectorDim; // 38. BBlockTransferSrcVectorDim
|
||||
oss << "," << kBBlockTransferSrcScalarPerVector; // 39. BBlockTransferSrcScalarPerVector
|
||||
oss << ","
|
||||
<< kBBlockTransferDstScalarPerVectorK1; // 40. BBlockTransferDstScalarPerVector_BK1
|
||||
oss << "," << kBBlockLdsExtraN; // 41. BBlockLdsExtraN
|
||||
oss << "," << kCShuffleMXdlPerWavePerShuffle; // 42. CShuffleMXdlPerWavePerShuffle
|
||||
oss << "," << kCShuffleNXdlPerWavePerShuffle; // 43. CShuffleNXdlPerWavePerShuffle
|
||||
<< kBBlockTransferDstScalarPerVectorK1; // 40. BBlockTransferDstScalarPerVector_BK1
|
||||
oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 41. BBlockLdsExtraN
|
||||
oss << "," << kCShuffleMXdlPerWavePerShuffle; // 42. CShuffleMXdlPerWavePerShuffle
|
||||
oss << "," << kCShuffleNXdlPerWavePerShuffle; // 43. CShuffleNXdlPerWavePerShuffle
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kCThreadClusterLengths); // 44. CDEBlockTransferClusterLengths
|
||||
|
||||
@@ -316,7 +316,7 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvFwdMultiple
|
||||
oss << "," << kABlockTransferSrcScalarPerVector; // 31. ABlockTransferSrcScalarPerVector
|
||||
oss << ","
|
||||
<< kABlockTransferDstScalarPerVectorK1; // 32. ABlockTransferDstScalarPerVector_AK1
|
||||
oss << "," << kABlockLdsExtraM; // 33. ABlockLdsExtraM
|
||||
oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 33. ABlockLdsExtraM
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kBThreadClusterLengths); // 34. BBlockTransferThreadClusterLengths
|
||||
@@ -329,10 +329,10 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvFwdMultiple
|
||||
oss << "," << kBBlockTransferSrcVectorDim; // 37. BBlockTransferSrcVectorDim
|
||||
oss << "," << kBBlockTransferSrcScalarPerVector; // 38. BBlockTransferSrcScalarPerVector
|
||||
oss << ","
|
||||
<< kBBlockTransferDstScalarPerVectorK1; // 39. BBlockTransferDstScalarPerVector_BK1
|
||||
oss << "," << kBBlockLdsExtraN; // 40. BBlockLdsExtraN
|
||||
oss << "," << kCShuffleMXdlPerWavePerShuffle; // 41. CShuffleMXdlPerWavePerShuffle
|
||||
oss << "," << kCShuffleNXdlPerWavePerShuffle; // 42. CShuffleNXdlPerWavePerShuffle
|
||||
<< kBBlockTransferDstScalarPerVectorK1; // 39. BBlockTransferDstScalarPerVector_BK1
|
||||
oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 40. BBlockLdsExtraN
|
||||
oss << "," << kCShuffleMXdlPerWavePerShuffle; // 41. CShuffleMXdlPerWavePerShuffle
|
||||
oss << "," << kCShuffleNXdlPerWavePerShuffle; // 42. CShuffleNXdlPerWavePerShuffle
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kCThreadClusterLengths); // 43. CDEBlockTransferClusterLengths
|
||||
|
||||
@@ -311,7 +311,7 @@ struct InstanceTraits<
|
||||
oss << "," << kABlockTransferSrcScalarPerVector; // 32. ABlockTransferSrcScalarPerVector
|
||||
oss << ","
|
||||
<< kABlockTransferDstScalarPerVectorK1; // 33. ABlockTransferDstScalarPerVector_AK1
|
||||
oss << "," << kABlockLdsExtraM; // 34. ABlockLdsExtraM
|
||||
oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 34. ABlockLdsExtraM
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kBThreadClusterLengths); // 35. BBlockTransferThreadClusterLengths
|
||||
@@ -324,10 +324,10 @@ struct InstanceTraits<
|
||||
oss << "," << kBBlockTransferSrcVectorDim; // 38. BBlockTransferSrcVectorDim
|
||||
oss << "," << kBBlockTransferSrcScalarPerVector; // 39. BBlockTransferSrcScalarPerVector
|
||||
oss << ","
|
||||
<< kBBlockTransferDstScalarPerVectorK1; // 40. BBlockTransferDstScalarPerVector_BK1
|
||||
oss << "," << kBBlockLdsExtraN; // 41. BBlockLdsExtraN
|
||||
oss << "," << kCShuffleMXdlPerWavePerShuffle; // 42. CShuffleMXdlPerWavePerShuffle
|
||||
oss << "," << kCShuffleNXdlPerWavePerShuffle; // 43. CShuffleNXdlPerWavePerShuffle
|
||||
<< kBBlockTransferDstScalarPerVectorK1; // 40. BBlockTransferDstScalarPerVector_BK1
|
||||
oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 41. BBlockLdsExtraN
|
||||
oss << "," << kCShuffleMXdlPerWavePerShuffle; // 42. CShuffleMXdlPerWavePerShuffle
|
||||
oss << "," << kCShuffleNXdlPerWavePerShuffle; // 43. CShuffleNXdlPerWavePerShuffle
|
||||
oss << ","
|
||||
<< detail::array_to_string(
|
||||
kCThreadClusterLengths); // 44. CDEBlockTransferClusterLengths
|
||||
|
||||
@@ -13,7 +13,7 @@ namespace ck_tile::builder {
|
||||
|
||||
enum class DataType
|
||||
{
|
||||
UNDEFINDED = 0,
|
||||
UNDEFINED_DATA_TYPE = 0,
|
||||
FP32,
|
||||
FP16,
|
||||
BF16,
|
||||
@@ -25,7 +25,7 @@ enum class DataType
|
||||
|
||||
enum class TensorLayout
|
||||
{
|
||||
UNDEFINED,
|
||||
UNDEFINED_TENSOR_LAYOUT = 0,
|
||||
|
||||
// Bias tensors
|
||||
GC,
|
||||
@@ -212,219 +212,269 @@ enum class ConvAlgorithmSpecialization
|
||||
LARGE_TENSOR
|
||||
};
|
||||
|
||||
// ostream operator overloads for enum classes
|
||||
inline std::ostream& operator<<(std::ostream& os, DataType dt)
|
||||
// toString methods for enum classes
|
||||
inline std::string_view toString(DataType dt)
|
||||
{
|
||||
using enum DataType;
|
||||
switch(dt)
|
||||
{
|
||||
case FP16: return os << "FP16";
|
||||
case FP32: return os << "FP32";
|
||||
case BF16: return os << "BF16";
|
||||
case FP8: return os << "FP8";
|
||||
case INT32: return os << "INT32";
|
||||
case I8: return os << "I8";
|
||||
case U8: return os << "U8";
|
||||
case UNDEFINDED: return os << "UNDEFINDED";
|
||||
default: return os << "Unknown";
|
||||
case FP16: return "FP16";
|
||||
case FP32: return "FP32";
|
||||
case BF16: return "BF16";
|
||||
case FP8: return "FP8";
|
||||
case INT32: return "INT32";
|
||||
case I8: return "I8";
|
||||
case U8: return "U8";
|
||||
case UNDEFINED_DATA_TYPE: return "UNDEFINED_DATA_TYPE";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, ConvDirection dir)
|
||||
inline std::string_view toString(ConvDirection dir)
|
||||
{
|
||||
using enum ConvDirection;
|
||||
switch(dir)
|
||||
{
|
||||
case FORWARD: return os << "Forward";
|
||||
case BACKWARD_DATA: return os << "Backward Data";
|
||||
case BACKWARD_WEIGHT: return os << "Backward Weight";
|
||||
default: return os << "Unknown";
|
||||
case FORWARD: return "Forward";
|
||||
case BACKWARD_DATA: return "Backward Data";
|
||||
case BACKWARD_WEIGHT: return "Backward Weight";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, ElementwiseOperation op)
|
||||
inline std::string_view toString(ElementwiseOperation op)
|
||||
{
|
||||
using enum ElementwiseOperation;
|
||||
switch(op)
|
||||
{
|
||||
case CLAMP: return os << "CLAMP";
|
||||
case SCALE: return os << "SCALE";
|
||||
case PASS_THROUGH: return os << "PASS_THROUGH";
|
||||
case BIAS_BNORM_CLAMP: return os << "BIAS_BNORM_CLAMP";
|
||||
case SCALEADD_SCALEADD_RELU: return os << "SCALEADD_SCALEADD_RELU";
|
||||
default: return os << "Unknown";
|
||||
case CLAMP: return "CLAMP";
|
||||
case SCALE: return "SCALE";
|
||||
case PASS_THROUGH: return "PASS_THROUGH";
|
||||
case BIAS_BNORM_CLAMP: return "BIAS_BNORM_CLAMP";
|
||||
case SCALEADD_SCALEADD_RELU: return "SCALEADD_SCALEADD_RELU";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, PipelineVersion ver)
|
||||
inline std::string_view toString(PipelineVersion ver)
|
||||
{
|
||||
using enum PipelineVersion;
|
||||
switch(ver)
|
||||
{
|
||||
case V1: return os << "V1";
|
||||
case V2: return os << "V2";
|
||||
case V3: return os << "V3";
|
||||
case V4: return os << "V4";
|
||||
case V5: return os << "V5";
|
||||
case WEIGHT_ONLY: return os << "WEIGHT_ONLY";
|
||||
default: return os << "Unknown";
|
||||
case V1: return "V1";
|
||||
case V2: return "V2";
|
||||
case V3: return "V3";
|
||||
case V4: return "V4";
|
||||
case V5: return "V5";
|
||||
case WEIGHT_ONLY: return "WEIGHT_ONLY";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, GemmSpecialization spec)
|
||||
inline std::string_view toString(GemmSpecialization spec)
|
||||
{
|
||||
using enum GemmSpecialization;
|
||||
switch(spec)
|
||||
{
|
||||
case Default: return os << "Default";
|
||||
case MPadding: return os << "MPadding";
|
||||
case NPadding: return os << "NPadding";
|
||||
case KPadding: return os << "KPadding";
|
||||
case MNPadding: return os << "MNPadding";
|
||||
case MKPadding: return os << "MKPadding";
|
||||
case NKPadding: return os << "NKPadding";
|
||||
case MNKPadding: return os << "MNKPadding";
|
||||
case OPadding: return os << "OPadding";
|
||||
case MOPadding: return os << "MOPadding";
|
||||
case NOPadding: return os << "NOPadding";
|
||||
case KOPadding: return os << "KOPadding";
|
||||
case MNOPadding: return os << "MNOPadding";
|
||||
case MKOPadding: return os << "MKOPadding";
|
||||
case NKOPadding: return os << "NKOPadding";
|
||||
case MNKOPadding: return os << "MNKOPadding";
|
||||
default: return os << "Unknown";
|
||||
case Default: return "Default";
|
||||
case MPadding: return "MPadding";
|
||||
case NPadding: return "NPadding";
|
||||
case KPadding: return "KPadding";
|
||||
case MNPadding: return "MNPadding";
|
||||
case MKPadding: return "MKPadding";
|
||||
case NKPadding: return "NKPadding";
|
||||
case MNKPadding: return "MNKPadding";
|
||||
case OPadding: return "OPadding";
|
||||
case MOPadding: return "MOPadding";
|
||||
case NOPadding: return "NOPadding";
|
||||
case KOPadding: return "KOPadding";
|
||||
case MNOPadding: return "MNOPadding";
|
||||
case MKOPadding: return "MKOPadding";
|
||||
case NKOPadding: return "NKOPadding";
|
||||
case MNKOPadding: return "MNKOPadding";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, ConvFwdSpecialization spec)
|
||||
inline std::string_view toString(ConvFwdSpecialization spec)
|
||||
{
|
||||
using enum ConvFwdSpecialization;
|
||||
switch(spec)
|
||||
{
|
||||
case DEFAULT: return os << "DEFAULT";
|
||||
case FILTER_1X1_PAD0: return os << "FILTER_1X1_PAD0";
|
||||
case FILTER_1X1_STRIDE1_PAD0: return os << "FILTER_1X1_STRIDE1_PAD0";
|
||||
case FILTER_3x3: return os << "FILTER_3x3";
|
||||
default: return os << "Unknown";
|
||||
case DEFAULT: return "DEFAULT";
|
||||
case FILTER_1X1_PAD0: return "FILTER_1X1_PAD0";
|
||||
case FILTER_1X1_STRIDE1_PAD0: return "FILTER_1X1_STRIDE1_PAD0";
|
||||
case FILTER_3x3: return "FILTER_3x3";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, ConvBwdDataSpecialization spec)
|
||||
inline std::string_view toString(ConvBwdDataSpecialization spec)
|
||||
{
|
||||
using enum ConvBwdDataSpecialization;
|
||||
switch(spec)
|
||||
{
|
||||
case DEFAULT: return os << "DEFAULT";
|
||||
case FILTER_1X1_STRIDE1_PAD0: return os << "FILTER_1X1_STRIDE1_PAD0";
|
||||
default: return os << "Unknown";
|
||||
case DEFAULT: return "DEFAULT";
|
||||
case FILTER_1X1_STRIDE1_PAD0: return "FILTER_1X1_STRIDE1_PAD0";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, ConvBwdWeightSpecialization spec)
|
||||
inline std::string_view toString(ConvBwdWeightSpecialization spec)
|
||||
{
|
||||
using enum ConvBwdWeightSpecialization;
|
||||
switch(spec)
|
||||
{
|
||||
case DEFAULT: return os << "DEFAULT";
|
||||
case FILTER_1X1_STRIDE1_PAD0: return os << "FILTER_1X1_STRIDE1_PAD0";
|
||||
case FILTER_1X1_PAD0: return os << "FILTER_1X1_PAD0";
|
||||
case ODD_C: return os << "ODD_C";
|
||||
default: return os << "Unknown";
|
||||
case DEFAULT: return "DEFAULT";
|
||||
case FILTER_1X1_STRIDE1_PAD0: return "FILTER_1X1_STRIDE1_PAD0";
|
||||
case FILTER_1X1_PAD0: return "FILTER_1X1_PAD0";
|
||||
case ODD_C: return "ODD_C";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, GemmPadding padding)
|
||||
inline std::string_view toString(GemmPadding padding)
|
||||
{
|
||||
using enum GemmPadding;
|
||||
switch(padding)
|
||||
{
|
||||
case DEFAULT: return os << "DEFAULT";
|
||||
case M_PADDING: return os << "M_PADDING";
|
||||
case N_PADDING: return os << "N_PADDING";
|
||||
case K_PADDING: return os << "K_PADDING";
|
||||
case MN_PADDING: return os << "MN_PADDING";
|
||||
case MK_PADDING: return os << "MK_PADDING";
|
||||
case NK_PADDING: return os << "NK_PADDING";
|
||||
case MNK_PADDING: return os << "MNK_PADDING";
|
||||
case O_PADDING: return os << "O_PADDING";
|
||||
case MO_PADDING: return os << "MO_PADDING";
|
||||
case NO_PADDING: return os << "NO_PADDING";
|
||||
case KO_PADDING: return os << "KO_PADDING";
|
||||
case MNO_PADDING: return os << "MNO_PADDING";
|
||||
case MKO_PADDING: return os << "MKO_PADDING";
|
||||
case NKO_PADDING: return os << "NKO_PADDING";
|
||||
case MNKO_PADDING: return os << "MNKO_PADDING";
|
||||
default: return os << "Unknown";
|
||||
case DEFAULT: return "DEFAULT";
|
||||
case M_PADDING: return "M_PADDING";
|
||||
case N_PADDING: return "N_PADDING";
|
||||
case K_PADDING: return "K_PADDING";
|
||||
case MN_PADDING: return "MN_PADDING";
|
||||
case MK_PADDING: return "MK_PADDING";
|
||||
case NK_PADDING: return "NK_PADDING";
|
||||
case MNK_PADDING: return "MNK_PADDING";
|
||||
case O_PADDING: return "O_PADDING";
|
||||
case MO_PADDING: return "MO_PADDING";
|
||||
case NO_PADDING: return "NO_PADDING";
|
||||
case KO_PADDING: return "KO_PADDING";
|
||||
case MNO_PADDING: return "MNO_PADDING";
|
||||
case MKO_PADDING: return "MKO_PADDING";
|
||||
case NKO_PADDING: return "NKO_PADDING";
|
||||
case MNKO_PADDING: return "MNKO_PADDING";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, PipelineScheduler sched)
|
||||
inline std::string_view toString(PipelineScheduler sched)
|
||||
{
|
||||
using enum PipelineScheduler;
|
||||
switch(sched)
|
||||
{
|
||||
case DEFAULT: return os << "DEFAULT";
|
||||
case INTRAWAVE: return os << "INTRAWAVE";
|
||||
case INTERWAVE: return os << "INTERWAVE";
|
||||
default: return os << "Unknown";
|
||||
case DEFAULT: return "DEFAULT";
|
||||
case INTRAWAVE: return "INTRAWAVE";
|
||||
case INTERWAVE: return "INTERWAVE";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, TensorLayout layout)
|
||||
inline std::string_view toString(TensorLayout layout)
|
||||
{
|
||||
using enum TensorLayout;
|
||||
switch(layout)
|
||||
{
|
||||
case GNCW: return os << "GNCW";
|
||||
case GNWC: return os << "GNWC";
|
||||
case NWGC: return os << "NWGC";
|
||||
case NGCW: return os << "NGCW";
|
||||
case G_NW_C_strided: return os << "G_NW_C_strided";
|
||||
case GNCHW: return os << "GNCHW";
|
||||
case GNHWC: return os << "GNHWC";
|
||||
case NHWGC: return os << "NHWGC";
|
||||
case NGCHW: return os << "NGCHW";
|
||||
case G_NHW_C_strided: return os << "G_NHW_C_strided";
|
||||
case GNCDHW: return os << "GNCDHW";
|
||||
case GNDHWC: return os << "GNDHWC";
|
||||
case NDHWGC: return os << "NDHWGC";
|
||||
case NGCDHW: return os << "NGCDHW";
|
||||
case G_NDHW_C_strided: return os << "G_NDHW_C_strided";
|
||||
case GKXC: return os << "GKXC";
|
||||
case GKCX: return os << "GKCX";
|
||||
case KXGC: return os << "KXGC";
|
||||
case G_K_X_C_strided: return os << "G_K_X_C_strided";
|
||||
case GKYXC: return os << "GKYXC";
|
||||
case GKCYX: return os << "GKCYX";
|
||||
case KYXGC: return os << "KYXGC";
|
||||
case G_K_YX_C_strided: return os << "G_K_YX_C_strided";
|
||||
case GKZYXC: return os << "GKZYXC";
|
||||
case GKCZYX: return os << "GKCZYX";
|
||||
case KZYXGC: return os << "KZYXGC";
|
||||
case G_K_ZYX_C_strided: return os << "G_K_ZYX_C_strided";
|
||||
case GNKW: return os << "GNKW";
|
||||
case GNWK: return os << "GNWK";
|
||||
case NWGK: return os << "NWGK";
|
||||
case NGKW: return os << "NGKW";
|
||||
case G_NW_K_strided: return os << "G_NW_K_strided";
|
||||
case GNKHW: return os << "GNKHW";
|
||||
case GNHWK: return os << "GNHWK";
|
||||
case NHWGK: return os << "NHWGK";
|
||||
case NGKHW: return os << "NGKHW";
|
||||
case G_NHW_K_strided: return os << "G_NHW_K_strided";
|
||||
case GNKDHW: return os << "GNKDHW";
|
||||
case GNDHWK: return os << "GNDHWK";
|
||||
case NDHWGK: return os << "NDHWGK";
|
||||
case NGKDHW: return os << "NGKDHW";
|
||||
case G_NDHW_K_strided: return os << "G_NDHW_K_strided";
|
||||
case GC: return os << "GC";
|
||||
case G_C_strided: return os << "G_C_strided";
|
||||
case G_K_strided: return os << "G_K_strided";
|
||||
case UNDEFINED: return os << "UNDEFINED";
|
||||
default: return os << "Unknown";
|
||||
case GNCW: return "GNCW";
|
||||
case GNWC: return "GNWC";
|
||||
case NWGC: return "NWGC";
|
||||
case NGCW: return "NGCW";
|
||||
case G_NW_C_strided: return "G_NW_C_strided";
|
||||
case GNCHW: return "GNCHW";
|
||||
case GNHWC: return "GNHWC";
|
||||
case NHWGC: return "NHWGC";
|
||||
case NGCHW: return "NGCHW";
|
||||
case G_NHW_C_strided: return "G_NHW_C_strided";
|
||||
case GNCDHW: return "GNCDHW";
|
||||
case GNDHWC: return "GNDHWC";
|
||||
case NDHWGC: return "NDHWGC";
|
||||
case NGCDHW: return "NGCDHW";
|
||||
case G_NDHW_C_strided: return "G_NDHW_C_strided";
|
||||
case GKXC: return "GKXC";
|
||||
case GKCX: return "GKCX";
|
||||
case KXGC: return "KXGC";
|
||||
case G_K_X_C_strided: return "G_K_X_C_strided";
|
||||
case GKYXC: return "GKYXC";
|
||||
case GKCYX: return "GKCYX";
|
||||
case KYXGC: return "KYXGC";
|
||||
case G_K_YX_C_strided: return "G_K_YX_C_strided";
|
||||
case GKZYXC: return "GKZYXC";
|
||||
case GKCZYX: return "GKCZYX";
|
||||
case KZYXGC: return "KZYXGC";
|
||||
case G_K_ZYX_C_strided: return "G_K_ZYX_C_strided";
|
||||
case GNKW: return "GNKW";
|
||||
case GNWK: return "GNWK";
|
||||
case NWGK: return "NWGK";
|
||||
case NGKW: return "NGKW";
|
||||
case G_NW_K_strided: return "G_NW_K_strided";
|
||||
case GNKHW: return "GNKHW";
|
||||
case GNHWK: return "GNHWK";
|
||||
case NHWGK: return "NHWGK";
|
||||
case NGKHW: return "NGKHW";
|
||||
case G_NHW_K_strided: return "G_NHW_K_strided";
|
||||
case GNKDHW: return "GNKDHW";
|
||||
case GNDHWK: return "GNDHWK";
|
||||
case NDHWGK: return "NDHWGK";
|
||||
case NGKDHW: return "NGKDHW";
|
||||
case G_NDHW_K_strided: return "G_NDHW_K_strided";
|
||||
case GC: return "GC";
|
||||
case G_C_strided: return "G_C_strided";
|
||||
case G_K_strided: return "G_K_strided";
|
||||
case UNDEFINED_TENSOR_LAYOUT: return "UNDEFINED_TENSOR_LAYOUT";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
// ostream operator overloads for enum classes
|
||||
inline std::ostream& operator<<(std::ostream& os, DataType dt) { return os << toString(dt); }
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, ConvDirection dir) { return os << toString(dir); }
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, ElementwiseOperation op)
|
||||
{
|
||||
return os << toString(op);
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, PipelineVersion ver)
|
||||
{
|
||||
return os << toString(ver);
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, GemmSpecialization spec)
|
||||
{
|
||||
return os << toString(spec);
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, ConvFwdSpecialization spec)
|
||||
{
|
||||
return os << toString(spec);
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, ConvBwdDataSpecialization spec)
|
||||
{
|
||||
return os << toString(spec);
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, ConvBwdWeightSpecialization spec)
|
||||
{
|
||||
return os << toString(spec);
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, GemmPadding padding)
|
||||
{
|
||||
return os << toString(padding);
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, PipelineScheduler sched)
|
||||
{
|
||||
return os << toString(sched);
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, TensorLayout layout)
|
||||
{
|
||||
return os << toString(layout);
|
||||
}
|
||||
|
||||
// ostream operator overload for std::variant of convolution specializations
|
||||
inline std::ostream& operator<<(std::ostream& os,
|
||||
const std::variant<ConvFwdSpecialization,
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "utils/conv_algorithm_type_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
@@ -13,15 +14,19 @@ using namespace ck_tile::builder::test_utils;
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_1D_BF16_ChannelsFirst_scale)
|
||||
{
|
||||
using enum ck_tile::builder::ConvDirection;
|
||||
using enum ck_tile::builder::DataType;
|
||||
using enum ck_tile::builder::TensorLayout;
|
||||
using enum ck_tile::builder::ElementwiseOperation;
|
||||
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 1,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.data_type = DataType::BF16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::NGCW}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKXC}},
|
||||
.output = {.config = {.layout = TensorLayout::NGKW},
|
||||
.operation = {.elementwise_operation = ElementwiseOperation::SCALE}}};
|
||||
.direction = FORWARD,
|
||||
.data_type = BF16,
|
||||
.accumulation_data_type = FP32,
|
||||
.input = {.config = {.layout = NGCW}},
|
||||
.weight = {.config = {.layout = GKXC}},
|
||||
.output = {.config = {.layout = NGKW}, .operation = {.elementwise_operation = SCALE}}};
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
|
||||
@@ -33,8 +38,10 @@ TEST(FwdConvInstances,
|
||||
.with_block_gemm(BlockGemmDesc_v2_intrawave);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
const auto expected_transfer_parameters = to_string(FwdConvAlgorithm);
|
||||
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
|
||||
"256,256,256,32",
|
||||
expected_transfer_parameters,
|
||||
"NGCW,GKXC,EmptyTuple,NGKW",
|
||||
"PassThrough,PassThrough,Scale",
|
||||
"Filter1x1Stride1Pad0",
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "utils/conv_algorithm_type_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
@@ -12,13 +13,17 @@ using namespace ck_tile::builder::test_utils;
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Instance_1D_FP16_ChannelsFirst)
|
||||
{
|
||||
using enum ck_tile::builder::ConvDirection;
|
||||
using enum ck_tile::builder::DataType;
|
||||
using enum ck_tile::builder::TensorLayout;
|
||||
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 1,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.data_type = DataType::FP16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::NWGC}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKXC}},
|
||||
.output = {.config = {.layout = TensorLayout::NWGK}}};
|
||||
.direction = FORWARD,
|
||||
.data_type = FP16,
|
||||
.accumulation_data_type = FP32,
|
||||
.input = {.config = {.layout = NWGC}},
|
||||
.weight = {.config = {.layout = GKXC}},
|
||||
.output = {.config = {.layout = NWGK}}};
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{}
|
||||
@@ -29,11 +34,13 @@ TEST(FwdConvInstances,
|
||||
.with_prefetch_config(1, 2, PipelineScheduler::DEFAULT);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
const auto expected_transfer_parameters = to_string(FwdConvAlgorithm);
|
||||
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle",
|
||||
expected_transfer_parameters,
|
||||
"NWGC,GKXC,EmptyTuple,NWGK",
|
||||
"PassThrough,PassThrough,PassThrough",
|
||||
"MNKPadding",
|
||||
"64,64,32,32",
|
||||
"Default"});
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "utils/conv_algorithm_type_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
@@ -14,13 +15,17 @@ using namespace ck_tile::builder::test_utils;
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_Instance_1D_FP32_ChannelsFirst_scale)
|
||||
{
|
||||
using enum ck_tile::builder::ConvDirection;
|
||||
using enum ck_tile::builder::DataType;
|
||||
using enum ck_tile::builder::TensorLayout;
|
||||
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 1,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.data_type = DataType::I8,
|
||||
.accumulation_data_type = DataType::INT32,
|
||||
.input = {.config = {.layout = TensorLayout::GNWC}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKXC}},
|
||||
.output = {.config = {.layout = TensorLayout::GNWK}}};
|
||||
.direction = FORWARD,
|
||||
.data_type = I8,
|
||||
.accumulation_data_type = INT32,
|
||||
.input = {.config = {.layout = GNWC}},
|
||||
.weight = {.config = {.layout = GKXC}},
|
||||
.output = {.config = {.layout = GNWK}}};
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle{}
|
||||
@@ -31,8 +36,10 @@ TEST(FwdConvInstances,
|
||||
.with_prefetch_config(1, 0, PipelineScheduler::DEFAULT);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
const auto expected_transfer_parameters = to_string(FwdConvAlgorithm);
|
||||
run_test<Builder>({"DeviceGroupedConvFwdMultipleD_Wmma_CShuffle",
|
||||
"128,64,64,64",
|
||||
expected_transfer_parameters,
|
||||
"GNWC,GKXC,EmptyTuple,GNWK",
|
||||
"PassThrough,PassThrough,PassThrough",
|
||||
"Default"});
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "utils/conv_algorithm_type_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
@@ -12,13 +13,17 @@ using namespace ck_tile::builder::test_utils;
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_ChannelsLast)
|
||||
{
|
||||
using enum ck_tile::builder::ConvDirection;
|
||||
using enum ck_tile::builder::DataType;
|
||||
using enum ck_tile::builder::TensorLayout;
|
||||
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.data_type = DataType::BF16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::NHWGC}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKYXC}},
|
||||
.output = {.config = {.layout = TensorLayout::NHWGK}}};
|
||||
.direction = FORWARD,
|
||||
.data_type = BF16,
|
||||
.accumulation_data_type = FP32,
|
||||
.input = {.config = {.layout = NHWGC}},
|
||||
.weight = {.config = {.layout = GKYXC}},
|
||||
.output = {.config = {.layout = NHWGK}}};
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
|
||||
@@ -29,8 +34,10 @@ TEST(FwdConvInstances,
|
||||
.with_block_gemm(BlockGemmDesc_v1_intrawave);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
const auto expected_transfer_parameters = to_string(FwdConvAlgorithm);
|
||||
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
|
||||
"256,256,256,32",
|
||||
expected_transfer_parameters,
|
||||
"Default",
|
||||
"NHWGC,GKYXC,EmptyTuple,NHWGK",
|
||||
"PassThrough,PassThrough,PassThrough",
|
||||
@@ -43,13 +50,17 @@ TEST(FwdConvInstances,
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_NHWGC_Filter3x3)
|
||||
{
|
||||
using enum ck_tile::builder::ConvDirection;
|
||||
using enum ck_tile::builder::DataType;
|
||||
using enum ck_tile::builder::TensorLayout;
|
||||
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.data_type = DataType::BF16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::NHWGC}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKYXC}},
|
||||
.output = {.config = {.layout = TensorLayout::NHWGK}}};
|
||||
.direction = FORWARD,
|
||||
.data_type = BF16,
|
||||
.accumulation_data_type = FP32,
|
||||
.input = {.config = {.layout = NHWGC}},
|
||||
.weight = {.config = {.layout = GKYXC}},
|
||||
.output = {.config = {.layout = NHWGK}}};
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
|
||||
@@ -60,7 +71,10 @@ TEST(FwdConvInstances,
|
||||
.with_block_gemm(BlockGemmDesc_v5_intrawave);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
const auto expected_transfer_parameters = to_string(FwdConvAlgorithm);
|
||||
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
|
||||
expected_transfer_parameters,
|
||||
"Filter3x3",
|
||||
"NHWGC,GKYXC,EmptyTuple,NHWGK",
|
||||
"PassThrough,PassThrough,PassThrough",
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "utils/conv_algorithm_type_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
@@ -12,19 +13,22 @@ using namespace ck_tile::builder::test_utils;
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Instance_2D_BF16_scale_add_relu)
|
||||
{
|
||||
using enum ck_tile::builder::ConvDirection;
|
||||
using enum ck_tile::builder::DataType;
|
||||
using enum ck_tile::builder::TensorLayout;
|
||||
using enum ck_tile::builder::ElementwiseOperation;
|
||||
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.data_type = DataType::BF16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::NHWGC}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKYXC, .data_type = DataType::BF16}},
|
||||
.output = ConvolutionTensor{
|
||||
.config = {.layout = TensorLayout::NHWGK},
|
||||
.operation = TensorOperation<>{.elementwise_operation =
|
||||
ElementwiseOperation::SCALEADD_SCALEADD_RELU}
|
||||
.with_auxiliary_operand_configs<TensorLayout::NHWGK,
|
||||
TensorLayout::G_K_strided>()}};
|
||||
.direction = FORWARD,
|
||||
.data_type = BF16,
|
||||
.accumulation_data_type = FP32,
|
||||
.input = {.config = {.layout = NHWGC}},
|
||||
.weight = {.config = {.layout = GKYXC, .data_type = BF16}},
|
||||
.output = ConvolutionTensor{
|
||||
.config = {.layout = NHWGK},
|
||||
.operation = TensorOperation<>{.elementwise_operation = SCALEADD_SCALEADD_RELU}
|
||||
.with_auxiliary_operand_configs<NHWGK, G_K_strided>()}};
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{}
|
||||
@@ -35,10 +39,12 @@ TEST(FwdConvInstances,
|
||||
.with_prefetch_config(1, 1, PipelineScheduler::DEFAULT);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
const auto expected_transfer_parameters = to_string(FwdConvAlgorithm);
|
||||
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle",
|
||||
expected_transfer_parameters,
|
||||
"NHWGC,GKYXC,Tuple(NHWGK,G_K),NHWGK",
|
||||
"PassThrough,PassThrough,ScaleAddScaleAddRelu",
|
||||
"64,64,32,32",
|
||||
"MNKPadding",
|
||||
"Default"});
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "utils/conv_algorithm_type_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
@@ -10,13 +11,17 @@ using namespace ck_tile::builder::test_utils;
|
||||
|
||||
TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_GNHWC)
|
||||
{
|
||||
using enum ck_tile::builder::ConvDirection;
|
||||
using enum ck_tile::builder::DataType;
|
||||
using enum ck_tile::builder::TensorLayout;
|
||||
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.data_type = DataType::FP16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::GNHWC}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKYXC}},
|
||||
.output = {.config = {.layout = TensorLayout::GNHWK}}};
|
||||
.direction = FORWARD,
|
||||
.data_type = FP16,
|
||||
.accumulation_data_type = FP32,
|
||||
.input = {.config = {.layout = GNHWC}},
|
||||
.weight = {.config = {.layout = GKYXC}},
|
||||
.output = {.config = {.layout = GNHWK}}};
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK{}
|
||||
@@ -27,8 +32,10 @@ TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Ins
|
||||
.with_dl_transfer(DlFwdTransfer);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
const auto expected_transfer_parameters = to_string(FwdConvAlgorithm);
|
||||
run_test<Builder>({"DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK",
|
||||
"256,128,128,16",
|
||||
expected_transfer_parameters,
|
||||
"Default",
|
||||
"MNKPadding",
|
||||
"GNHWC,GKYXC,EmptyTuple,GNHWK",
|
||||
@@ -38,13 +45,17 @@ TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Ins
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_FILTER_1X1_PAD0)
|
||||
{
|
||||
using enum ck_tile::builder::ConvDirection;
|
||||
using enum ck_tile::builder::DataType;
|
||||
using enum ck_tile::builder::TensorLayout;
|
||||
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.data_type = DataType::FP16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::GNHWC}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKYXC}},
|
||||
.output = {.config = {.layout = TensorLayout::GNHWK}}};
|
||||
.direction = FORWARD,
|
||||
.data_type = FP16,
|
||||
.accumulation_data_type = FP32,
|
||||
.input = {.config = {.layout = GNHWC}},
|
||||
.weight = {.config = {.layout = GKYXC}},
|
||||
.output = {.config = {.layout = GNHWK}}};
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK{}
|
||||
@@ -56,8 +67,10 @@ TEST(FwdConvInstances,
|
||||
.with_dl_transfer(DlFwdTransfer);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
const auto expected_transfer_parameters = to_string(FwdConvAlgorithm);
|
||||
run_test<Builder>({"DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK",
|
||||
"256,128,128,16",
|
||||
expected_transfer_parameters,
|
||||
"Filter1x1Pad0",
|
||||
"MNKPadding",
|
||||
"GNHWC,GKYXC,EmptyTuple,GNHWK",
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "utils/conv_algorithm_type_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
@@ -11,13 +12,17 @@ using namespace ck_tile::builder::test_utils;
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP16_GNHWC)
|
||||
{
|
||||
using enum ck_tile::builder::ConvDirection;
|
||||
using enum ck_tile::builder::DataType;
|
||||
using enum ck_tile::builder::TensorLayout;
|
||||
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.data_type = DataType::FP16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::GNHWC}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKYXC}},
|
||||
.output = {.config = {.layout = TensorLayout::GNHWK}}};
|
||||
.direction = FORWARD,
|
||||
.data_type = FP16,
|
||||
.accumulation_data_type = FP32,
|
||||
.input = {.config = {.layout = GNHWC}},
|
||||
.weight = {.config = {.layout = GKYXC}},
|
||||
.output = {.config = {.layout = GNHWK}}};
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
|
||||
@@ -29,8 +34,10 @@ TEST(FwdConvInstances,
|
||||
.with_block_gemm(BlockGemmDesc_v3_intrawave);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
const auto expected_transfer_parameters = to_string(FwdConvAlgorithm);
|
||||
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
|
||||
"256,256,256,32",
|
||||
expected_transfer_parameters,
|
||||
"Filter1x1Pad0",
|
||||
"Intrawave",
|
||||
"v3",
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "utils/conv_algorithm_type_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
@@ -11,13 +12,17 @@ using namespace ck_tile::builder::test_utils;
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP32_NGCHW_GKCYX)
|
||||
{
|
||||
using enum ck_tile::builder::ConvDirection;
|
||||
using enum ck_tile::builder::DataType;
|
||||
using enum ck_tile::builder::TensorLayout;
|
||||
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.data_type = DataType::FP32,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::NGCHW}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKCYX}},
|
||||
.output = {.config = {.layout = TensorLayout::NGKHW}}};
|
||||
.direction = FORWARD,
|
||||
.data_type = FP32,
|
||||
.accumulation_data_type = FP32,
|
||||
.input = {.config = {.layout = NGCHW}},
|
||||
.weight = {.config = {.layout = GKCYX}},
|
||||
.output = {.config = {.layout = NGKHW}}};
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
|
||||
@@ -29,8 +34,10 @@ TEST(FwdConvInstances,
|
||||
.with_block_gemm(BlockGemmDesc_v4_intrawave);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
const auto expected_transfer_parameters = to_string(FwdConvAlgorithm);
|
||||
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
|
||||
"256,128,128,32",
|
||||
expected_transfer_parameters,
|
||||
"Filter1x1Stride1Pad0",
|
||||
"Intrawave",
|
||||
"v4",
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "utils/conv_algorithm_type_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
@@ -12,13 +13,17 @@ using namespace ck_tile::builder::test_utils;
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_Instance_2D_FP8_ChannelsLast)
|
||||
{
|
||||
using enum ck_tile::builder::ConvDirection;
|
||||
using enum ck_tile::builder::DataType;
|
||||
using enum ck_tile::builder::TensorLayout;
|
||||
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.data_type = DataType::FP8,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::NHWGC}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKYXC}},
|
||||
.output = {.config = {.layout = TensorLayout::NHWGK}}};
|
||||
.direction = FORWARD,
|
||||
.data_type = FP8,
|
||||
.accumulation_data_type = FP32,
|
||||
.input = {.config = {.layout = NHWGC}},
|
||||
.weight = {.config = {.layout = GKYXC}},
|
||||
.output = {.config = {.layout = NHWGK}}};
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle{}
|
||||
@@ -29,8 +34,10 @@ TEST(FwdConvInstances,
|
||||
.with_prefetch_config(1, 1, PipelineScheduler::DEFAULT);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
const auto expected_transfer_parameters = to_string(FwdConvAlgorithm);
|
||||
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle",
|
||||
"256,256,128,32",
|
||||
expected_transfer_parameters,
|
||||
"Default",
|
||||
"NHWGC,GKYXC,EmptyTuple,NHWGK",
|
||||
"PassThrough,PassThrough,PassThrough",
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "utils/conv_algorithm_type_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
@@ -11,13 +12,17 @@ using namespace ck_tile::builder::test_utils;
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Instance_2D_FP16_GNHWC)
|
||||
{
|
||||
using enum ck_tile::builder::ConvDirection;
|
||||
using enum ck_tile::builder::DataType;
|
||||
using enum ck_tile::builder::TensorLayout;
|
||||
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.data_type = DataType::FP16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::GNHWC}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKYXC}},
|
||||
.output = {.config = {.layout = TensorLayout::GNHWK}}};
|
||||
.direction = FORWARD,
|
||||
.data_type = FP16,
|
||||
.accumulation_data_type = FP32,
|
||||
.input = {.config = {.layout = GNHWC}},
|
||||
.weight = {.config = {.layout = GKYXC}},
|
||||
.output = {.config = {.layout = GNHWK}}};
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{
|
||||
@@ -30,8 +35,10 @@ TEST(FwdConvInstances,
|
||||
.with_prefetch_config(1, 1, PipelineScheduler::DEFAULT)};
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
const auto expected_transfer_parameters = to_string(FwdConvAlgorithm);
|
||||
run_test<Builder>({"DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor",
|
||||
"256,256,128,32",
|
||||
expected_transfer_parameters,
|
||||
"Default",
|
||||
"GNHWC,GKYXC,EmptyTuple,GNHWK",
|
||||
"PassThrough,PassThrough,PassThrough",
|
||||
@@ -42,13 +49,17 @@ TEST(
|
||||
FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor_Instance_2D_FP16_GNHWC_Filter1x1Pad0)
|
||||
{
|
||||
using enum ck_tile::builder::ConvDirection;
|
||||
using enum ck_tile::builder::DataType;
|
||||
using enum ck_tile::builder::TensorLayout;
|
||||
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.data_type = DataType::FP16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::GNHWC}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKYXC}},
|
||||
.output = {.config = {.layout = TensorLayout::GNHWK}}};
|
||||
.direction = FORWARD,
|
||||
.data_type = FP16,
|
||||
.accumulation_data_type = FP32,
|
||||
.input = {.config = {.layout = GNHWC}},
|
||||
.weight = {.config = {.layout = GKYXC}},
|
||||
.output = {.config = {.layout = GNHWK}}};
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor{
|
||||
@@ -61,8 +72,10 @@ TEST(
|
||||
.with_prefetch_config(1, 1, PipelineScheduler::DEFAULT)};
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
const auto expected_transfer_parameters = to_string(FwdConvAlgorithm);
|
||||
run_test<Builder>({"DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor",
|
||||
"128,128,128,32",
|
||||
expected_transfer_parameters,
|
||||
"Filter1x1Pad0",
|
||||
"GNHWC,GKYXC,EmptyTuple,GNHWK",
|
||||
"PassThrough,PassThrough,PassThrough",
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "utils/conv_algorithm_type_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
@@ -12,14 +13,17 @@ using namespace ck_tile::builder::test_utils;
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_BF16_GNDHWC)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 3,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.data_type = DataType::BF16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::GNDHWC}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKZYXC}},
|
||||
.output = {.config = {.layout = TensorLayout::GNDHWK}}};
|
||||
using enum ck_tile::builder::ConvDirection;
|
||||
using enum ck_tile::builder::DataType;
|
||||
using enum ck_tile::builder::TensorLayout;
|
||||
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 3,
|
||||
.direction = FORWARD,
|
||||
.data_type = BF16,
|
||||
.accumulation_data_type = FP32,
|
||||
.input = {.config = {.layout = GNDHWC}},
|
||||
.weight = {.config = {.layout = GKZYXC}},
|
||||
.output = {.config = {.layout = GNDHWK}}};
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
|
||||
@@ -30,8 +34,10 @@ TEST(FwdConvInstances,
|
||||
.with_block_gemm(BlockGemmDesc_v3_intrawave);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
const auto expected_transfer_parameters = to_string(FwdConvAlgorithm);
|
||||
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
|
||||
"256,256,256,32",
|
||||
expected_transfer_parameters,
|
||||
"Default",
|
||||
"Intrawave",
|
||||
"v3",
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "utils/conv_algorithm_type_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
@@ -12,14 +13,17 @@ using namespace ck_tile::builder::test_utils;
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP16_NDHWGC_ChannelsLast)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 3,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.data_type = DataType::FP16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::NDHWGC}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKZYXC}},
|
||||
.output = {.config = {.layout = TensorLayout::NDHWGK}}};
|
||||
using enum ck_tile::builder::ConvDirection;
|
||||
using enum ck_tile::builder::DataType;
|
||||
using enum ck_tile::builder::TensorLayout;
|
||||
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 3,
|
||||
.direction = FORWARD,
|
||||
.data_type = FP16,
|
||||
.accumulation_data_type = FP32,
|
||||
.input = {.config = {.layout = NDHWGC}},
|
||||
.weight = {.config = {.layout = GKZYXC}},
|
||||
.output = {.config = {.layout = NDHWGK}}};
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
|
||||
@@ -31,8 +35,10 @@ TEST(FwdConvInstances,
|
||||
.with_block_gemm(BlockGemmDesc_v4_intrawave);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
const auto expected_transfer_parameters = to_string(FwdConvAlgorithm);
|
||||
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
|
||||
"256,128,128,32",
|
||||
expected_transfer_parameters,
|
||||
"Filter1x1Pad0",
|
||||
"Intrawave",
|
||||
"v4",
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "utils/conv_algorithm_type_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
@@ -12,14 +13,17 @@ using namespace ck_tile::builder::test_utils;
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP32_ChannelsFirst)
|
||||
{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 3,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.data_type = DataType::FP32,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::NGCDHW}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKCZYX}},
|
||||
.output = {.config = {.layout = TensorLayout::NGKDHW}}};
|
||||
using enum ck_tile::builder::ConvDirection;
|
||||
using enum ck_tile::builder::DataType;
|
||||
using enum ck_tile::builder::TensorLayout;
|
||||
|
||||
constexpr ConvSignature FwdConvSignature{.spatial_dim = 3,
|
||||
.direction = FORWARD,
|
||||
.data_type = FP32,
|
||||
.accumulation_data_type = FP32,
|
||||
.input = {.config = {.layout = NGCDHW}},
|
||||
.weight = {.config = {.layout = GKCZYX}},
|
||||
.output = {.config = {.layout = NGKDHW}}};
|
||||
|
||||
constexpr auto FwdConvAlgorithm =
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3{}
|
||||
@@ -31,8 +35,10 @@ TEST(FwdConvInstances,
|
||||
.with_block_gemm(BlockGemmDesc_v1_intrawave);
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
const auto expected_transfer_parameters = to_string(FwdConvAlgorithm);
|
||||
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
|
||||
"256,256,256,32",
|
||||
expected_transfer_parameters,
|
||||
"Filter1x1Pad0",
|
||||
"Intrawave",
|
||||
"v1",
|
||||
|
||||
@@ -12,6 +12,12 @@
|
||||
|
||||
namespace {
|
||||
|
||||
using ck_tile::builder::ConvDirection;
|
||||
using ck_tile::builder::DataType;
|
||||
using ck_tile::builder::ElementwiseOperation;
|
||||
using ck_tile::builder::PipelineScheduler;
|
||||
using ck_tile::builder::PipelineVersion;
|
||||
using ck_tile::builder::TensorLayout;
|
||||
using ::testing::ElementsAre;
|
||||
|
||||
// Test fixture for ConvTraits tests
|
||||
@@ -84,15 +90,13 @@ TEST_F(ConvTraitsTest, ConvFwdTraitsExtraction)
|
||||
|
||||
// Verify signature information
|
||||
EXPECT_EQ(Traits::spatial_dim, 2);
|
||||
EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD);
|
||||
EXPECT_EQ(Traits::direction, ConvDirection::FORWARD);
|
||||
EXPECT_THAT(Traits::layout,
|
||||
::testing::ElementsAre(ck_tile::builder::TensorLayout::GNHWC,
|
||||
ck_tile::builder::TensorLayout::GKYXC,
|
||||
ck_tile::builder::TensorLayout::GNHWK));
|
||||
EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16);
|
||||
EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
|
||||
EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
|
||||
EXPECT_EQ(Traits::output_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
|
||||
ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK));
|
||||
EXPECT_EQ(Traits::data_type, DataType::FP16);
|
||||
EXPECT_EQ(Traits::input_element_op, ElementwiseOperation::PASS_THROUGH);
|
||||
EXPECT_EQ(Traits::weight_element_op, ElementwiseOperation::PASS_THROUGH);
|
||||
EXPECT_EQ(Traits::output_element_op, ElementwiseOperation::PASS_THROUGH);
|
||||
|
||||
// Verify specializations
|
||||
EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
|
||||
@@ -145,8 +149,8 @@ TEST_F(ConvTraitsTest, ConvFwdTraitsExtraction)
|
||||
EXPECT_EQ(Traits::c_tile_transfer.scalar_per_vector, 8);
|
||||
|
||||
// Verify pipeline configuration
|
||||
EXPECT_EQ(Traits::pipeline_scheduler, ck_tile::builder::PipelineScheduler::INTRAWAVE);
|
||||
EXPECT_EQ(Traits::pipeline_version, ck_tile::builder::PipelineVersion::V1);
|
||||
EXPECT_EQ(Traits::pipeline_scheduler, PipelineScheduler::INTRAWAVE);
|
||||
EXPECT_EQ(Traits::pipeline_version, PipelineVersion::V1);
|
||||
}
|
||||
|
||||
// Test ConvTraits with DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
@@ -214,15 +218,13 @@ TEST_F(ConvTraitsTest, ConvFwdBaseTraitsExtraction)
|
||||
|
||||
// Verify signature information
|
||||
EXPECT_EQ(Traits::spatial_dim, 2);
|
||||
EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD);
|
||||
EXPECT_EQ(Traits::direction, ConvDirection::FORWARD);
|
||||
EXPECT_THAT(Traits::layout,
|
||||
::testing::ElementsAre(ck_tile::builder::TensorLayout::GNHWC,
|
||||
ck_tile::builder::TensorLayout::GKYXC,
|
||||
ck_tile::builder::TensorLayout::GNHWK));
|
||||
EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16);
|
||||
EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
|
||||
EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
|
||||
EXPECT_EQ(Traits::output_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
|
||||
ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK));
|
||||
EXPECT_EQ(Traits::data_type, DataType::FP16);
|
||||
EXPECT_EQ(Traits::input_element_op, ElementwiseOperation::PASS_THROUGH);
|
||||
EXPECT_EQ(Traits::weight_element_op, ElementwiseOperation::PASS_THROUGH);
|
||||
EXPECT_EQ(Traits::output_element_op, ElementwiseOperation::PASS_THROUGH);
|
||||
|
||||
// Verify specializations
|
||||
EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
|
||||
@@ -300,15 +302,13 @@ TEST_F(ConvTraitsTest, ConvFwdLargeTensorTraitsExtraction)
|
||||
|
||||
// Verify signature information
|
||||
EXPECT_EQ(Traits::spatial_dim, 2);
|
||||
EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD);
|
||||
EXPECT_EQ(Traits::direction, ConvDirection::FORWARD);
|
||||
EXPECT_THAT(Traits::layout,
|
||||
::testing::ElementsAre(ck_tile::builder::TensorLayout::GNHWC,
|
||||
ck_tile::builder::TensorLayout::GKYXC,
|
||||
ck_tile::builder::TensorLayout::GNHWK));
|
||||
EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16);
|
||||
EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
|
||||
EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
|
||||
EXPECT_EQ(Traits::output_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH);
|
||||
ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK));
|
||||
EXPECT_EQ(Traits::data_type, DataType::FP16);
|
||||
EXPECT_EQ(Traits::input_element_op, ElementwiseOperation::PASS_THROUGH);
|
||||
EXPECT_EQ(Traits::weight_element_op, ElementwiseOperation::PASS_THROUGH);
|
||||
EXPECT_EQ(Traits::output_element_op, ElementwiseOperation::PASS_THROUGH);
|
||||
|
||||
// Verify specializations
|
||||
EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
|
||||
|
||||
@@ -14,8 +14,8 @@ struct TensorConfig
|
||||
{
|
||||
TensorLayout layout;
|
||||
// Optional data types, override the type defined in the signature if provided.
|
||||
DataType data_type{DataType::UNDEFINDED};
|
||||
DataType compute_type{DataType::UNDEFINDED};
|
||||
DataType data_type{DataType::UNDEFINED_DATA_TYPE};
|
||||
DataType compute_type{DataType::UNDEFINED_DATA_TYPE};
|
||||
};
|
||||
|
||||
template <TensorConfig... Configs>
|
||||
|
||||
@@ -76,54 +76,54 @@ struct F8_ConvScale
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>"
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>"
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
@@ -141,54 +141,54 @@ struct F8_BF8_comb1_ConvScale
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>"
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,bf8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,bf8,Default,1>"
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
@@ -206,54 +206,54 @@ struct F8_BF8_comb2_ConvScale
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>"
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,bf8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,bf8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,bf8,Default,1>"
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
@@ -271,54 +271,54 @@ struct F8_BF8_comb3_ConvScale
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>"
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,bf8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,bf8,fp8,Default,1>"
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
@@ -336,54 +336,54 @@ struct F8_float_CombConvScale
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>"
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>"
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
@@ -401,54 +401,54 @@ struct F8_ConvScaleRelu
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>"
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvScaleRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>"
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
@@ -466,54 +466,54 @@ struct F8_CombConvScaleRelu
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>"
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,UnaryCombinedOp,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>"
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
@@ -531,54 +531,54 @@ struct F8_ConvScaleAdd
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>"
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK),NDHWGK,fp8,fp8,fp32,fp32,Tuple(fp32),fp8,PassThrough,PassThrough,ConvScaleAdd,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>"
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
@@ -596,54 +596,54 @@ struct F8_ConvInvscale
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>"
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,32,32,8,8,32,32,2,1,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,128,128,64,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,128,32,128,32,8,8,32,32,1,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,128,64,128,32,8,8,32,32,2,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,256,32,8,8,32,32,2,4,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,256,128,64,32,8,8,32,32,2,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,256,64,128,32,8,8,32,32,1,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,64,32,64,32,8,8,32,32,1,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp8,fp8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp8,fp8,fp32,fp32,EmptyTuple,fp8,PassThrough,PassThrough,ConvInvscale,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),8,fp8,fp8,Default,1>"
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
|
||||
@@ -85,9 +85,9 @@ struct DyOp_F32_2
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,fp32,fp32,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,256,128,128,16,4,4,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,fp32,fp32,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,32,16,4,4,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,fp32,fp32,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,64,16,4,4,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>"
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,fp32,fp32,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,256,128,128,16,4,4,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,fp32,fp32,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,32,16,4,4,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,fp32,fp32,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,64,16,4,4,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>"
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
@@ -98,9 +98,9 @@ struct DyOp_F32_3
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp32,fp32,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,256,128,128,16,4,4,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp32,fp32,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,32,16,4,4,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp32,fp32,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,64,16,4,4,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>"
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp32,fp32,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,256,128,128,16,4,4,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp32,fp32,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,32,16,4,4,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp32,fp32,fp32,fp32,EmptyTuple,fp32,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,64,16,4,4,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>"
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
@@ -111,9 +111,9 @@ struct DyOp_F16_2
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,fp16,fp16,fp32,fp16,EmptyTuple,fp16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,fp16,fp16,fp32,fp16,EmptyTuple,fp16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,fp16,fp16,fp32,fp16,EmptyTuple,fp16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>"
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,fp16,fp16,fp32,fp16,EmptyTuple,fp16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,fp16,fp16,fp32,fp16,EmptyTuple,fp16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,fp16,fp16,fp32,fp16,EmptyTuple,fp16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>"
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
@@ -124,9 +124,9 @@ struct DyOp_F16_3
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp16,fp16,fp32,fp16,EmptyTuple,fp16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp16,fp16,fp32,fp16,EmptyTuple,fp16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp16,fp16,fp32,fp16,EmptyTuple,fp16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>"
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp16,fp16,fp32,fp16,EmptyTuple,fp16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp16,fp16,fp32,fp16,EmptyTuple,fp16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,fp16,fp16,fp32,fp16,EmptyTuple,fp16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>"
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
@@ -137,9 +137,9 @@ struct DyOp_BF16_2
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,bf16,bf16,fp32,bf16,EmptyTuple,bf16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,bf16,bf16,fp32,bf16,EmptyTuple,bf16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,bf16,bf16,fp32,bf16,EmptyTuple,bf16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>"
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,bf16,bf16,fp32,bf16,EmptyTuple,bf16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,bf16,bf16,fp32,bf16,EmptyTuple,bf16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,bf16,bf16,fp32,bf16,EmptyTuple,bf16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>"
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
@@ -150,9 +150,9 @@ struct DyOp_BF16_3
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf16,bf16,fp32,bf16,EmptyTuple,bf16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf16,bf16,fp32,bf16,EmptyTuple,bf16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf16,bf16,fp32,bf16,EmptyTuple,bf16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>"
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf16,bf16,fp32,bf16,EmptyTuple,bf16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf16,bf16,fp32,bf16,EmptyTuple,bf16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,bf16,bf16,fp32,bf16,EmptyTuple,bf16,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>"
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
@@ -163,9 +163,9 @@ struct DyOp_INT8_2
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,s8,s8,s32,s8,EmptyTuple,s8,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,s8,s8,s32,s8,EmptyTuple,s8,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,s8,s8,s32,s8,EmptyTuple,s8,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>"
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,s8,s8,s32,s8,EmptyTuple,s8,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,s8,s8,s32,s8,EmptyTuple,s8,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,s8,s8,s32,s8,EmptyTuple,s8,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>"
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
@@ -176,9 +176,9 @@ struct DyOp_INT8_3
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,s8,s8,s32,s8,EmptyTuple,s8,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,s8,s8,s32,s8,EmptyTuple,s8,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,s8,s8,s32,s8,EmptyTuple,s8,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>"
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,s8,s8,s32,s8,EmptyTuple,s8,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,s8,s8,s32,s8,EmptyTuple,s8,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,s8,s8,s32,s8,EmptyTuple,s8,PassThrough,PassThrough,DynamicUnaryOp,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>"
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
|
||||
@@ -53,18 +53,18 @@ struct F32
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,256,128,128,16,4,4,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,256,256,128,16,4,4,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,64,64,32,16,4,4,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,64,64,64,16,4,4,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,256,128,128,16,4,4,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,256,256,128,16,4,4,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,64,64,32,16,4,4,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,64,64,64,16,4,4,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,16,4,4,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,16,4,4,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,16,4,4,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,16,4,4,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>"
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,256,128,128,16,4,4,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,256,256,128,16,4,4,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,64,64,32,16,4,4,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,64,64,64,16,4,4,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,256,128,128,16,4,4,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,256,256,128,16,4,4,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,64,64,32,16,4,4,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,64,64,64,16,4,4,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,16,4,4,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,16,4,4,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,16,4,4,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp32,fp32),Tuple(fp32,fp32),fp32,fp32,EmptyTuple,fp32,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,16,4,4,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>"
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
@@ -75,18 +75,18 @@ struct F16
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>"
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(fp16,fp16),Tuple(fp16,fp16),fp32,fp16,EmptyTuple,fp16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>"
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
@@ -97,18 +97,18 @@ struct BF16
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>"
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(bf16,bf16),Tuple(bf16,bf16),fp32,bf16,EmptyTuple,bf16,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>"
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
@@ -119,18 +119,18 @@ struct S8
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,1,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>"
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,EmptyTuple,NDHWGK,Tuple(s8,s8),Tuple(s8,s8),s32,s8,EmptyTuple,s8,ScaleAdd,ScaleAdd,PassThrough,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,1,true,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>"
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
|
||||
@@ -54,18 +54,18 @@ struct F32
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,256,128,128,16,4,4,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,256,256,128,16,4,4,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,64,64,32,16,4,4,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,64,64,64,16,4,4,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,256,128,128,16,4,4,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,256,256,128,16,4,4,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,64,64,32,16,4,4,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,64,64,64,16,4,4,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,16,4,4,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,16,4,4,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,16,4,4,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,1,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,16,4,4,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,1,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>"
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,256,128,128,16,4,4,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,256,256,128,16,4,4,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,64,64,32,16,4,4,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,64,64,64,16,4,4,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,256,128,128,16,4,4,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,256,256,128,16,4,4,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,64,64,32,16,4,4,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,64,64,64,16,4,4,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,16,4,4,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,16,4,4,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,1,1,Seq(1,16,1,16),4,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,16,4,4,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp32,fp32,fp32,fp32,Tuple(fp32,fp32),fp32,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,16,4,4,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,4,true,1,1,Seq(1,8,1,8),1,fp32,fp32,Default,1>"
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
@@ -76,18 +76,18 @@ struct F16
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>"
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,fp16,fp16,fp32,fp16,Tuple(fp16,fp16),fp16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,fp16,fp16,Default,1>"
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
@@ -98,18 +98,18 @@ struct BF16
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>"
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,bf16,bf16,fp32,bf16,Tuple(bf16,bf16),bf16,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,bf16,bf16,Default,1>"
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
@@ -120,18 +120,18 @@ struct S8
|
||||
|
||||
constexpr static auto expected = {
|
||||
// clang-format off
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,1,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>"
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Default,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,128,128,32,8,8,32,32,2,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,256,256,128,32,8,8,32,32,4,2,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,32,1,8),8,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,32,32,8,8,32,32,2,1,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>",
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<3,NDHWGC,GKZYXC,Tuple(NDHWGK,G_K),NDHWGK,s8,s8,s32,s8,Tuple(fp32,fp32),s8,PassThrough,PassThrough,ScaleAddScaleAddRelu,Filter1x1Stride1Pad0,MNKPadding,1,64,64,64,32,8,8,32,32,2,2,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,Seq(4,16,1),Seq(1,0,2),Seq(1,0,2),2,1,8,true,1,1,Seq(1,16,1,4),1,s8,s8,Default,1>"
|
||||
// clang-format on
|
||||
};
|
||||
};
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
@@ -30,8 +30,8 @@ static_assert(!ckb::TensorOperatorDescriptor<InvalidTensorOp>);
|
||||
struct TensorConfig
|
||||
{
|
||||
ckb::TensorLayout layout;
|
||||
ckb::DataType data_type{ckb::DataType::UNDEFINDED};
|
||||
ckb::DataType compute_type{ckb::DataType::UNDEFINDED};
|
||||
ckb::DataType data_type{ckb::DataType::UNDEFINED_DATA_TYPE};
|
||||
ckb::DataType compute_type{ckb::DataType::UNDEFINED_DATA_TYPE};
|
||||
};
|
||||
|
||||
struct ConvTensorSimple
|
||||
@@ -55,39 +55,49 @@ struct ConvTensorWithInvalidOp
|
||||
// This includes dimensionality, direction, data layout, and data type.
|
||||
struct ConvSignature
|
||||
{
|
||||
using enum ckb::DataType;
|
||||
using enum ckb::TensorLayout;
|
||||
|
||||
int spatial_dim = 2;
|
||||
ckb::DataType data_type = ckb::DataType::FP16;
|
||||
ckb::DataType accumulation_data_type = ckb::DataType::FP32;
|
||||
ConvTensorSimple input = {.config = {ckb::TensorLayout::GNHWC}};
|
||||
ConvTensorSimple weight = {.config = {ckb::TensorLayout::GKYXC}};
|
||||
ConvTensorSimple output = {.config = {ckb::TensorLayout::GNHWK}};
|
||||
ckb::DataType data_type = FP16;
|
||||
ckb::DataType accumulation_data_type = FP32;
|
||||
ConvTensorSimple input = {.config = {GNHWC}};
|
||||
ConvTensorSimple weight = {.config = {GKYXC}};
|
||||
ConvTensorSimple output = {.config = {GNHWK}};
|
||||
};
|
||||
static_assert(ckb::ConvSignatureDescriptor<ConvSignature>);
|
||||
|
||||
// Compile time tests for concepts
|
||||
struct ConvSignatureWithOptionalParams
|
||||
{
|
||||
using enum ckb::DataType;
|
||||
using enum ckb::TensorLayout;
|
||||
using enum ckb::ConvDirection;
|
||||
using enum ckb::ElementwiseOperation;
|
||||
|
||||
int spatial_dim = 2;
|
||||
ckb::DataType data_type = ckb::DataType::FP16;
|
||||
ckb::DataType accumulation_data_type = ckb::DataType::FP32;
|
||||
ckb::ConvDirection direction = ckb::ConvDirection::FORWARD;
|
||||
ckb::DataType data_type = FP16;
|
||||
ckb::DataType accumulation_data_type = FP32;
|
||||
ckb::ConvDirection direction = FORWARD;
|
||||
ConvTensorWithOp input = {
|
||||
.config = {ckb::TensorLayout::GNHWC, ckb::DataType::FP16},
|
||||
.config = {GNHWC, FP16},
|
||||
};
|
||||
ConvTensorWithOp weight = {.config = {ckb::TensorLayout::GKYXC, ckb::DataType::FP16}};
|
||||
ConvTensorWithOp output = {.config = {ckb::TensorLayout::GNHWK, ckb::DataType::FP16},
|
||||
.operation = {ckb::ElementwiseOperation::SCALE}};
|
||||
ConvTensorWithOp weight = {.config = {GKYXC, FP16}};
|
||||
ConvTensorWithOp output = {.config = {GNHWK, FP16}, .operation = {SCALE}};
|
||||
};
|
||||
static_assert(ckb::ConvSignatureDescriptor<ConvSignatureWithOptionalParams>);
|
||||
|
||||
struct ConvSignatureWithInvalidOptionalParams
|
||||
{
|
||||
using enum ckb::DataType;
|
||||
using enum ckb::TensorLayout;
|
||||
|
||||
int spatial_dim = 2;
|
||||
ckb::DataType data_type = ckb::DataType::FP16;
|
||||
ckb::DataType accumulation_data_type = ckb::DataType::FP32;
|
||||
ConvTensorWithInvalidOp input = {.config = {ckb::TensorLayout::GNHWC}};
|
||||
ConvTensorWithInvalidOp weight = {.config = {ckb::TensorLayout::GKYXC}};
|
||||
ConvTensorWithInvalidOp output = {.config = {ckb::TensorLayout::GNHWK}};
|
||||
ckb::DataType data_type = FP16;
|
||||
ckb::DataType accumulation_data_type = FP32;
|
||||
ConvTensorWithInvalidOp input = {.config = {GNHWC}};
|
||||
ConvTensorWithInvalidOp weight = {.config = {GKYXC}};
|
||||
ConvTensorWithInvalidOp output = {.config = {GNHWK}};
|
||||
};
|
||||
static_assert(!ckb::ConvSignatureDescriptor<ConvSignatureWithInvalidOptionalParams>);
|
||||
|
||||
|
||||
@@ -262,14 +262,14 @@ TEST(InstanceTraits, V3InstanceStringReturnsCorrectFormat)
|
||||
",2" // ABlockTransferSrcVectorDim
|
||||
",8" // ABlockTransferSrcScalarPerVector
|
||||
",8" // ABlockTransferDstScalarPerVector_AK1
|
||||
",1" // ABlockLdsExtraM
|
||||
",true" // ABlockLdsExtraM
|
||||
",Seq(4,64,1)" // BBlockTransferThreadClusterLengths
|
||||
",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder
|
||||
",Seq(1,0,2)" // BBlockTransferSrcAccessOrder
|
||||
",2" // BBlockTransferSrcVectorDim
|
||||
",8" // BBlockTransferSrcScalarPerVector
|
||||
",8" // BBlockTransferDstScalarPerVector_BK1
|
||||
",1" // BBlockLdsExtraN
|
||||
",true" // BBlockLdsExtraN
|
||||
",1" // CShuffleMXdlPerWavePerShuffle
|
||||
",1" // CShuffleNXdlPerWavePerShuffle
|
||||
",Seq(1,32,1,8)" // CDEBlockTransferClusterLengths
|
||||
@@ -377,14 +377,14 @@ TEST(InstanceTraits, BaseInstanceStringReturnsCorrectFormat)
|
||||
",2" // ABlockTransferSrcVectorDim
|
||||
",8" // ABlockTransferSrcScalarPerVector
|
||||
",8" // ABlockTransferDstScalarPerVector_AK1
|
||||
",1" // ABlockLdsExtraM
|
||||
",true" // ABlockLdsExtraM
|
||||
",Seq(4,64,1)" // BBlockTransferThreadClusterLengths
|
||||
",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder
|
||||
",Seq(1,0,2)" // BBlockTransferSrcAccessOrder
|
||||
",2" // BBlockTransferSrcVectorDim
|
||||
",8" // BBlockTransferSrcScalarPerVector
|
||||
",8" // BBlockTransferDstScalarPerVector_BK1
|
||||
",1" // BBlockLdsExtraN
|
||||
",true" // BBlockLdsExtraN
|
||||
",1" // CShuffleMXdlPerWavePerShuffle
|
||||
",1" // CShuffleNXdlPerWavePerShuffle
|
||||
",Seq(1,32,1,8)" // CDEBlockTransferClusterLengths
|
||||
@@ -492,14 +492,14 @@ TEST(InstanceTraits, LargeTensorInstanceStringReturnsCorrectFormat)
|
||||
",2" // ABlockTransferSrcVectorDim
|
||||
",8" // ABlockTransferSrcScalarPerVector
|
||||
",8" // ABlockTransferDstScalarPerVector_AK1
|
||||
",1" // ABlockLdsExtraM
|
||||
",true" // ABlockLdsExtraM
|
||||
",Seq(4,64,1)" // BBlockTransferThreadClusterLengths
|
||||
",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder
|
||||
",Seq(1,0,2)" // BBlockTransferSrcAccessOrder
|
||||
",2" // BBlockTransferSrcVectorDim
|
||||
",8" // BBlockTransferSrcScalarPerVector
|
||||
",8" // BBlockTransferDstScalarPerVector_BK1
|
||||
",1" // BBlockLdsExtraN
|
||||
",true" // BBlockLdsExtraN
|
||||
",1" // CShuffleMXdlPerWavePerShuffle
|
||||
",1" // CShuffleNXdlPerWavePerShuffle
|
||||
",Seq(1,32,1,8)" // CDEBlockTransferClusterLengths
|
||||
|
||||
@@ -60,14 +60,14 @@ std::string expected_str = "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle"
|
||||
",2" // ABlockTransferSrcVectorDim
|
||||
",1" // ABlockTransferSrcScalarPerVector
|
||||
",8" // ABlockTransferDstScalarPerVector_AK1
|
||||
",1" // ABlockLdsExtraM
|
||||
",true" // ABlockLdsExtraM
|
||||
",Seq(4,16,1)" // BBlockTransferThreadClusterLengths
|
||||
",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder
|
||||
",Seq(1,0,2)" // BBlockTransferSrcAccessOrder
|
||||
",2" // BBlockTransferSrcVectorDim
|
||||
",1" // BBlockTransferSrcScalarPerVector
|
||||
",8" // BBlockTransferDstScalarPerVector_BK1
|
||||
",1" // BBlockLdsExtraN
|
||||
",true" // BBlockLdsExtraN
|
||||
",1" // CShuffleMXdlPerWavePerShuffle
|
||||
",1" // CShuffleNXdlPerWavePerShuffle
|
||||
",Seq(1,16,1,4)" // CDEBlockTransferClusterLengths
|
||||
|
||||
@@ -60,14 +60,14 @@ std::string expected_str = "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Ten
|
||||
",2" // ABlockTransferSrcVectorDim
|
||||
",1" // ABlockTransferSrcScalarPerVector
|
||||
",8" // ABlockTransferDstScalarPerVector_AK1
|
||||
",1" // ABlockLdsExtraM
|
||||
",true" // ABlockLdsExtraM
|
||||
",Seq(4,16,1)" // BBlockTransferThreadClusterLengths
|
||||
",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder
|
||||
",Seq(1,0,2)" // BBlockTransferSrcAccessOrder
|
||||
",2" // BBlockTransferSrcVectorDim
|
||||
",1" // BBlockTransferSrcScalarPerVector
|
||||
",8" // BBlockTransferDstScalarPerVector_BK1
|
||||
",1" // BBlockLdsExtraN
|
||||
",true" // BBlockLdsExtraN
|
||||
",1" // CShuffleMXdlPerWavePerShuffle
|
||||
",1" // CShuffleNXdlPerWavePerShuffle
|
||||
",Seq(1,16,1,4)" // CDEBlockTransferClusterLengths
|
||||
|
||||
@@ -60,14 +60,14 @@ std::string expected_str = "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3"
|
||||
",2" // ABlockTransferSrcVectorDim
|
||||
",8" // ABlockTransferSrcScalarPerVector
|
||||
",8" // ABlockTransferDstScalarPerVector_AK1
|
||||
",0" // ABlockLdsExtraM
|
||||
",false" // ABlockLdsExtraM
|
||||
",Seq(8,32,1)" // BBlockTransferThreadClusterLengths
|
||||
",Seq(1,0,2)" // BBlockTransferThreadClusterArrangeOrder
|
||||
",Seq(1,0,2)" // BBlockTransferSrcAccessOrder
|
||||
",2" // BBlockTransferSrcVectorDim
|
||||
",8" // BBlockTransferSrcScalarPerVector
|
||||
",8" // BBlockTransferDstScalarPerVector_BK1
|
||||
",0" // BBlockLdsExtraN
|
||||
",false" // BBlockLdsExtraN
|
||||
",1" // CShuffleMXdlPerWavePerShuffle
|
||||
",1" // CShuffleNXdlPerWavePerShuffle
|
||||
",Seq(1,32,1,8)" // CDEBlockTransferClusterLengths
|
||||
|
||||
@@ -34,8 +34,8 @@ TEST(InstanceSet, FromFactory)
|
||||
const auto* el =
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<2,NHWGC,GKYXC,EmptyTuple,NHWGK,fp16,fp16,"
|
||||
"fp32,fp16,EmptyTuple,fp16,PassThrough,PassThrough,PassThrough,Default,MNKPadding,1,128,"
|
||||
"128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,1,Seq(4,32,1),Seq(1,0,2),"
|
||||
"Seq(1,0,2),2,8,8,1,1,1,Seq(1,16,1,8),8,fp16,fp16,Default,1>";
|
||||
"128,128,32,8,8,32,32,4,2,Seq(4,32,1),Seq(1,0,2),Seq(1,0,2),2,8,8,true,Seq(4,32,1),"
|
||||
"Seq(1,0,2),Seq(1,0,2),2,8,8,true,1,1,Seq(1,16,1,8),8,fp16,fp16,Default,1>";
|
||||
EXPECT_THAT(instances.instances, testing::Contains(el));
|
||||
}
|
||||
|
||||
|
||||
@@ -9,27 +9,34 @@
|
||||
|
||||
namespace {
|
||||
|
||||
namespace ckb = ::ck_tile::builder;
|
||||
using ::ck_tile::builder::DataType;
|
||||
using ::ck_tile::builder::ElementwiseOperation;
|
||||
using ::ck_tile::builder::TensorLayout;
|
||||
using ::ck_tile::builder::factory::internal::AuxiliaryTensorLayouts;
|
||||
using ::ck_tile::builder::factory::internal::ConvTensorLayouts;
|
||||
using ::ck_tile::builder::factory::internal::LayoutToCK;
|
||||
namespace ckb = ck_tile::builder;
|
||||
using ck_tile::builder::DataType;
|
||||
using ck_tile::builder::ElementwiseOperation;
|
||||
using ck_tile::builder::TensorLayout;
|
||||
using ck_tile::builder::factory::internal::AuxiliaryTensorLayouts;
|
||||
using ck_tile::builder::factory::internal::ConvTensorLayouts;
|
||||
using ck_tile::builder::factory::internal::LayoutToCK;
|
||||
using ck_tile::builder::test::ConvolutionTensor;
|
||||
using ck_tile::builder::test::ConvSignature;
|
||||
using ck_tile::builder::test::TensorConfig;
|
||||
using ck_tile::builder::test::TensorOperation;
|
||||
|
||||
using namespace ::ck_tile::builder::test;
|
||||
using enum ::ck_tile::builder::ConvDirection;
|
||||
namespace enums {
|
||||
using enum ck_tile::builder::ConvDirection;
|
||||
using enum ck_tile::builder::TensorLayout;
|
||||
using enum ck_tile::builder::DataType;
|
||||
} // namespace enums
|
||||
|
||||
TEST(ConvTensorLayout, AssignsLayoutsFor1D_NWGC_GKXC_NWGK)
|
||||
{
|
||||
static constexpr auto sig =
|
||||
ConvSignature<>{.spatial_dim = 1,
|
||||
.direction = FORWARD,
|
||||
.data_type = DataType::FP16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::NWGC}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKXC}},
|
||||
.output = {.config = {.layout = TensorLayout::NWGK}}};
|
||||
using namespace enums;
|
||||
static constexpr auto sig = ConvSignature<>{.spatial_dim = 1,
|
||||
.direction = FORWARD,
|
||||
.data_type = FP16,
|
||||
.accumulation_data_type = FP32,
|
||||
.input = {.config = {.layout = NWGC}},
|
||||
.weight = {.config = {.layout = GKXC}},
|
||||
.output = {.config = {.layout = NWGK}}};
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 1, FORWARD>;
|
||||
|
||||
@@ -41,14 +48,14 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_NWGC_GKXC_NWGK)
|
||||
|
||||
TEST(ConvTensorLayout, AssignsLayoutsFor1D_NGCW_GKXC_NGKW)
|
||||
{
|
||||
static constexpr auto sig =
|
||||
ConvSignature<>{.spatial_dim = 1,
|
||||
.direction = FORWARD,
|
||||
.data_type = DataType::FP16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::NGCW}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKXC}},
|
||||
.output = {.config = {.layout = TensorLayout::NGKW}}};
|
||||
using namespace enums;
|
||||
static constexpr auto sig = ConvSignature<>{.spatial_dim = 1,
|
||||
.direction = FORWARD,
|
||||
.data_type = FP16,
|
||||
.accumulation_data_type = FP32,
|
||||
.input = {.config = {.layout = NGCW}},
|
||||
.weight = {.config = {.layout = GKXC}},
|
||||
.output = {.config = {.layout = NGKW}}};
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 1, FORWARD>;
|
||||
|
||||
@@ -60,14 +67,14 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_NGCW_GKXC_NGKW)
|
||||
|
||||
TEST(ConvTensorLayout, AssignsLayoutsFor1D_GNWC_GKXC_GNWK)
|
||||
{
|
||||
static constexpr auto sig =
|
||||
ConvSignature<>{.spatial_dim = 1,
|
||||
.direction = FORWARD,
|
||||
.data_type = DataType::FP16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::GNWC}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKXC}},
|
||||
.output = {.config = {.layout = TensorLayout::GNWK}}};
|
||||
using namespace enums;
|
||||
static constexpr auto sig = ConvSignature<>{.spatial_dim = 1,
|
||||
.direction = FORWARD,
|
||||
.data_type = FP16,
|
||||
.accumulation_data_type = FP32,
|
||||
.input = {.config = {.layout = GNWC}},
|
||||
.weight = {.config = {.layout = GKXC}},
|
||||
.output = {.config = {.layout = GNWK}}};
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 1, FORWARD>;
|
||||
|
||||
@@ -79,14 +86,14 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_GNWC_GKXC_GNWK)
|
||||
|
||||
TEST(ConvTensorLayout, AssignsLayoutsFor1D_NGCW_GKCX_NGKW)
|
||||
{
|
||||
static constexpr auto sig =
|
||||
ConvSignature<>{.spatial_dim = 1,
|
||||
.direction = FORWARD,
|
||||
.data_type = DataType::FP16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::NGCW}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKCX}},
|
||||
.output = {.config = {.layout = TensorLayout::NGKW}}};
|
||||
using namespace enums;
|
||||
static constexpr auto sig = ConvSignature<>{.spatial_dim = 1,
|
||||
.direction = FORWARD,
|
||||
.data_type = FP16,
|
||||
.accumulation_data_type = FP32,
|
||||
.input = {.config = {.layout = NGCW}},
|
||||
.weight = {.config = {.layout = GKCX}},
|
||||
.output = {.config = {.layout = NGKW}}};
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 1, FORWARD>;
|
||||
|
||||
@@ -98,14 +105,14 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_NGCW_GKCX_NGKW)
|
||||
|
||||
TEST(ConvTensorLayout, AssignsLayoutsFor2D_NGCHW_GKYXC_NGKHW)
|
||||
{
|
||||
static constexpr auto sig =
|
||||
ConvSignature<>{.spatial_dim = 2,
|
||||
.direction = FORWARD,
|
||||
.data_type = DataType::FP16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::NGCHW}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKYXC}},
|
||||
.output = {.config = {.layout = TensorLayout::NGKHW}}};
|
||||
using namespace enums;
|
||||
static constexpr auto sig = ConvSignature<>{.spatial_dim = 2,
|
||||
.direction = FORWARD,
|
||||
.data_type = FP16,
|
||||
.accumulation_data_type = FP32,
|
||||
.input = {.config = {.layout = NGCHW}},
|
||||
.weight = {.config = {.layout = GKYXC}},
|
||||
.output = {.config = {.layout = NGKHW}}};
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 2, FORWARD>;
|
||||
|
||||
@@ -117,14 +124,14 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_NGCHW_GKYXC_NGKHW)
|
||||
|
||||
TEST(ConvTensorLayout, AssignsLayoutsFor2D_NHWGC_GKYXC_NHWGK)
|
||||
{
|
||||
static constexpr auto sig =
|
||||
ConvSignature<>{.spatial_dim = 2,
|
||||
.direction = FORWARD,
|
||||
.data_type = DataType::FP16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::NHWGC}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKYXC}},
|
||||
.output = {.config = {.layout = TensorLayout::NHWGK}}};
|
||||
using namespace enums;
|
||||
static constexpr auto sig = ConvSignature<>{.spatial_dim = 2,
|
||||
.direction = FORWARD,
|
||||
.data_type = DataType::FP16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = NHWGC}},
|
||||
.weight = {.config = {.layout = GKYXC}},
|
||||
.output = {.config = {.layout = NHWGK}}};
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 2, FORWARD>;
|
||||
|
||||
@@ -136,14 +143,14 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_NHWGC_GKYXC_NHWGK)
|
||||
|
||||
TEST(ConvTensorLayout, AssignsLayoutsFor2D_GNHWC_GKYXC_GNHWK)
|
||||
{
|
||||
static constexpr auto sig =
|
||||
ConvSignature<>{.spatial_dim = 2,
|
||||
.direction = FORWARD,
|
||||
.data_type = DataType::FP16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::GNHWC}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKYXC}},
|
||||
.output = {.config = {.layout = TensorLayout::GNHWK}}};
|
||||
using namespace enums;
|
||||
static constexpr auto sig = ConvSignature<>{.spatial_dim = 2,
|
||||
.direction = FORWARD,
|
||||
.data_type = DataType::FP16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = GNHWC}},
|
||||
.weight = {.config = {.layout = GKYXC}},
|
||||
.output = {.config = {.layout = GNHWK}}};
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 2, FORWARD>;
|
||||
|
||||
@@ -155,14 +162,14 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_GNHWC_GKYXC_GNHWK)
|
||||
|
||||
TEST(ConvTensorLayout, AssignsLayoutsFor2D_NGCHW_GKCYX_NGKHW)
|
||||
{
|
||||
static constexpr auto sig =
|
||||
ConvSignature<>{.spatial_dim = 2,
|
||||
.direction = FORWARD,
|
||||
.data_type = DataType::FP16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::NGCHW}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKCYX}},
|
||||
.output = {.config = {.layout = TensorLayout::NGKHW}}};
|
||||
using namespace enums;
|
||||
static constexpr auto sig = ConvSignature<>{.spatial_dim = 2,
|
||||
.direction = FORWARD,
|
||||
.data_type = DataType::FP16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = NGCHW}},
|
||||
.weight = {.config = {.layout = GKCYX}},
|
||||
.output = {.config = {.layout = NGKHW}}};
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 2, FORWARD>;
|
||||
|
||||
@@ -174,14 +181,14 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_NGCHW_GKCYX_NGKHW)
|
||||
|
||||
TEST(ConvTensorLayout, AssignsLayoutsFor3D_NGCDHW_GKCZYX_NGKDHW)
|
||||
{
|
||||
static constexpr auto sig =
|
||||
ConvSignature<>{.spatial_dim = 3,
|
||||
.direction = FORWARD,
|
||||
.data_type = DataType::FP16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::NGCDHW}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKCZYX}},
|
||||
.output = {.config = {.layout = TensorLayout::NGKDHW}}};
|
||||
using namespace enums;
|
||||
static constexpr auto sig = ConvSignature<>{.spatial_dim = 3,
|
||||
.direction = FORWARD,
|
||||
.data_type = DataType::FP16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = NGCDHW}},
|
||||
.weight = {.config = {.layout = GKCZYX}},
|
||||
.output = {.config = {.layout = NGKDHW}}};
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 3, FORWARD>;
|
||||
|
||||
@@ -193,14 +200,14 @@ TEST(ConvTensorLayout, AssignsLayoutsFor3D_NGCDHW_GKCZYX_NGKDHW)
|
||||
|
||||
TEST(ConvTensorLayout, AssignsLayoutsFor3D_NDHWGC_GKZYXC_NDHWGK)
|
||||
{
|
||||
static constexpr auto sig =
|
||||
ConvSignature<>{.spatial_dim = 3,
|
||||
.direction = FORWARD,
|
||||
.data_type = DataType::FP16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::NDHWGC}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKZYXC}},
|
||||
.output = {.config = {.layout = TensorLayout::NDHWGK}}};
|
||||
using namespace enums;
|
||||
static constexpr auto sig = ConvSignature<>{.spatial_dim = 3,
|
||||
.direction = FORWARD,
|
||||
.data_type = DataType::FP16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = NDHWGC}},
|
||||
.weight = {.config = {.layout = GKZYXC}},
|
||||
.output = {.config = {.layout = NDHWGK}}};
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 3, FORWARD>;
|
||||
|
||||
@@ -212,14 +219,14 @@ TEST(ConvTensorLayout, AssignsLayoutsFor3D_NDHWGC_GKZYXC_NDHWGK)
|
||||
|
||||
TEST(ConvTensorLayout, AssignsLayoutsFor3D_GNDHWC_GKZYXC_GNDHWK)
|
||||
{
|
||||
static constexpr auto sig =
|
||||
ConvSignature<>{.spatial_dim = 3,
|
||||
.direction = FORWARD,
|
||||
.data_type = DataType::FP16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::GNDHWC}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKZYXC}},
|
||||
.output = {.config = {.layout = TensorLayout::GNDHWK}}};
|
||||
using namespace enums;
|
||||
static constexpr auto sig = ConvSignature<>{.spatial_dim = 3,
|
||||
.direction = FORWARD,
|
||||
.data_type = DataType::FP16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = GNDHWC}},
|
||||
.weight = {.config = {.layout = GKZYXC}},
|
||||
.output = {.config = {.layout = GNDHWK}}};
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 3, FORWARD>;
|
||||
|
||||
@@ -261,8 +268,10 @@ struct MockAuxiliaryTensorConfig
|
||||
|
||||
TEST(AuxiliaryTensorLayoutIntegration, SingleBiasTensorWithG_K_Layout)
|
||||
{
|
||||
using namespace enums;
|
||||
|
||||
static constexpr std::array<MockAuxiliaryTensorConfig, 1> aux_configs = {
|
||||
MockAuxiliaryTensorConfig{.layout = TensorLayout::G_K_strided}};
|
||||
MockAuxiliaryTensorConfig{.layout = G_K_strided}};
|
||||
|
||||
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2, FORWARD>;
|
||||
|
||||
@@ -273,6 +282,8 @@ TEST(AuxiliaryTensorLayoutIntegration, SingleBiasTensorWithG_K_Layout)
|
||||
|
||||
TEST(AuxiliaryTensorLayoutIntegration, SingleBiasTensorWithGC_Layout)
|
||||
{
|
||||
using namespace enums;
|
||||
|
||||
static constexpr std::array<MockAuxiliaryTensorConfig, 1> aux_configs = {
|
||||
MockAuxiliaryTensorConfig{.layout = TensorLayout::GC}};
|
||||
|
||||
@@ -285,8 +296,10 @@ TEST(AuxiliaryTensorLayoutIntegration, SingleBiasTensorWithGC_Layout)
|
||||
|
||||
TEST(AuxiliaryTensorLayoutIntegration, SingleBiasTensorWithG_C_Layout)
|
||||
{
|
||||
using namespace enums;
|
||||
|
||||
static constexpr std::array<MockAuxiliaryTensorConfig, 1> aux_configs = {
|
||||
MockAuxiliaryTensorConfig{.layout = TensorLayout::G_C_strided}};
|
||||
MockAuxiliaryTensorConfig{.layout = G_C_strided}};
|
||||
|
||||
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2, FORWARD>;
|
||||
|
||||
@@ -297,9 +310,11 @@ TEST(AuxiliaryTensorLayoutIntegration, SingleBiasTensorWithG_C_Layout)
|
||||
|
||||
TEST(AuxiliaryTensorLayoutIntegration, TwoAuxiliaryTensors)
|
||||
{
|
||||
using namespace enums;
|
||||
|
||||
static constexpr std::array<MockAuxiliaryTensorConfig, 2> aux_configs = {
|
||||
MockAuxiliaryTensorConfig{.layout = TensorLayout::G_K_strided},
|
||||
MockAuxiliaryTensorConfig{.layout = TensorLayout::GC}};
|
||||
MockAuxiliaryTensorConfig{.layout = GC}};
|
||||
|
||||
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2, FORWARD>;
|
||||
|
||||
@@ -311,10 +326,12 @@ TEST(AuxiliaryTensorLayoutIntegration, TwoAuxiliaryTensors)
|
||||
|
||||
TEST(AuxiliaryTensorLayoutIntegration, ThreeAuxiliaryTensors)
|
||||
{
|
||||
using namespace enums;
|
||||
|
||||
static constexpr std::array<MockAuxiliaryTensorConfig, 3> aux_configs = {
|
||||
MockAuxiliaryTensorConfig{.layout = TensorLayout::G_K_strided},
|
||||
MockAuxiliaryTensorConfig{.layout = TensorLayout::GC},
|
||||
MockAuxiliaryTensorConfig{.layout = TensorLayout::G_C_strided}};
|
||||
MockAuxiliaryTensorConfig{.layout = G_K_strided},
|
||||
MockAuxiliaryTensorConfig{.layout = GC},
|
||||
MockAuxiliaryTensorConfig{.layout = G_C_strided}};
|
||||
|
||||
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 2, FORWARD>;
|
||||
|
||||
@@ -327,8 +344,10 @@ TEST(AuxiliaryTensorLayoutIntegration, ThreeAuxiliaryTensors)
|
||||
|
||||
TEST(AuxiliaryTensorLayoutIntegration, WorksWith1DConvolution)
|
||||
{
|
||||
using namespace enums;
|
||||
|
||||
static constexpr std::array<MockAuxiliaryTensorConfig, 1> aux_configs = {
|
||||
MockAuxiliaryTensorConfig{.layout = TensorLayout::G_K_strided}};
|
||||
MockAuxiliaryTensorConfig{.layout = G_K_strided}};
|
||||
|
||||
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 1, FORWARD>;
|
||||
|
||||
@@ -339,8 +358,10 @@ TEST(AuxiliaryTensorLayoutIntegration, WorksWith1DConvolution)
|
||||
|
||||
TEST(AuxiliaryTensorLayoutIntegration, WorksWith3DConvolution)
|
||||
{
|
||||
using namespace enums;
|
||||
|
||||
static constexpr std::array<MockAuxiliaryTensorConfig, 1> aux_configs = {
|
||||
MockAuxiliaryTensorConfig{.layout = TensorLayout::GC}};
|
||||
MockAuxiliaryTensorConfig{.layout = GC}};
|
||||
|
||||
using AuxLayouts = AuxiliaryTensorLayouts<aux_configs, 3, FORWARD>;
|
||||
|
||||
@@ -351,7 +372,8 @@ TEST(AuxiliaryTensorLayoutIntegration, WorksWith3DConvolution)
|
||||
|
||||
TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithSingleBiasG_K)
|
||||
{
|
||||
using OutputOp = TensorOperation<TensorConfig{.layout = TensorLayout::G_K_strided}>;
|
||||
using namespace enums;
|
||||
using OutputOp = TensorOperation<TensorConfig{.layout = G_K_strided}>;
|
||||
|
||||
static constexpr auto sig =
|
||||
ConvSignature<ConvolutionTensor<>, ConvolutionTensor<>, ConvolutionTensor<OutputOp>>{
|
||||
@@ -359,9 +381,9 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithSingleBiasG_K)
|
||||
.direction = FORWARD,
|
||||
.data_type = DataType::FP16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::NGCHW}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKYXC}},
|
||||
.output = {.config = {.layout = TensorLayout::NGKHW},
|
||||
.input = {.config = {.layout = NGCHW}},
|
||||
.weight = {.config = {.layout = GKYXC}},
|
||||
.output = {.config = {.layout = NGKHW},
|
||||
.operation =
|
||||
OutputOp{.elementwise_operation = ElementwiseOperation::SCALE}}};
|
||||
|
||||
@@ -377,7 +399,8 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithSingleBiasG_K)
|
||||
|
||||
TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithSingleBiasGC)
|
||||
{
|
||||
using OutputOp = TensorOperation<TensorConfig{.layout = TensorLayout::GC}>;
|
||||
using namespace enums;
|
||||
using OutputOp = TensorOperation<TensorConfig{.layout = GC}>;
|
||||
|
||||
static constexpr auto sig =
|
||||
ConvSignature<ConvolutionTensor<>, ConvolutionTensor<>, ConvolutionTensor<OutputOp>>{
|
||||
@@ -385,9 +408,9 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithSingleBiasGC)
|
||||
.direction = FORWARD,
|
||||
.data_type = DataType::BF16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::NHWGC}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKYXC}},
|
||||
.output = {.config = {.layout = TensorLayout::NHWGK},
|
||||
.input = {.config = {.layout = NHWGC}},
|
||||
.weight = {.config = {.layout = GKYXC}},
|
||||
.output = {.config = {.layout = NHWGK},
|
||||
.operation =
|
||||
OutputOp{.elementwise_operation = ElementwiseOperation::SCALE}}};
|
||||
|
||||
@@ -403,8 +426,9 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithSingleBiasGC)
|
||||
|
||||
TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithTwoAuxiliaryTensors)
|
||||
{
|
||||
using OutputOp = TensorOperation<TensorConfig{.layout = TensorLayout::G_K_strided},
|
||||
TensorConfig{.layout = TensorLayout::GC}>;
|
||||
using namespace enums;
|
||||
using OutputOp =
|
||||
TensorOperation<TensorConfig{.layout = G_K_strided}, TensorConfig{.layout = GC}>;
|
||||
|
||||
static constexpr auto sig =
|
||||
ConvSignature<ConvolutionTensor<>, ConvolutionTensor<>, ConvolutionTensor<OutputOp>>{
|
||||
@@ -412,9 +436,9 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithTwoAuxiliaryTensors)
|
||||
.direction = FORWARD,
|
||||
.data_type = DataType::FP16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::GNHWC}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKYXC}},
|
||||
.output = {.config = {.layout = TensorLayout::GNHWK},
|
||||
.input = {.config = {.layout = GNHWC}},
|
||||
.weight = {.config = {.layout = GKYXC}},
|
||||
.output = {.config = {.layout = GNHWK},
|
||||
.operation = OutputOp{.elementwise_operation =
|
||||
ElementwiseOperation::SCALEADD_SCALEADD_RELU}}};
|
||||
|
||||
@@ -431,7 +455,8 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithTwoAuxiliaryTensors)
|
||||
|
||||
TEST(ConvTensorLayoutsWithAuxiliary, Conv1DWithBias)
|
||||
{
|
||||
using OutputOp = TensorOperation<TensorConfig{.layout = TensorLayout::G_K_strided}>;
|
||||
using namespace enums;
|
||||
using OutputOp = TensorOperation<TensorConfig{.layout = G_K_strided}>;
|
||||
|
||||
static constexpr auto sig =
|
||||
ConvSignature<ConvolutionTensor<>, ConvolutionTensor<>, ConvolutionTensor<OutputOp>>{
|
||||
@@ -439,9 +464,9 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv1DWithBias)
|
||||
.direction = FORWARD,
|
||||
.data_type = DataType::FP32,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::NWGC}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKXC}},
|
||||
.output = {.config = {.layout = TensorLayout::NWGK},
|
||||
.input = {.config = {.layout = NWGC}},
|
||||
.weight = {.config = {.layout = GKXC}},
|
||||
.output = {.config = {.layout = NWGK},
|
||||
.operation =
|
||||
OutputOp{.elementwise_operation = ElementwiseOperation::SCALE}}};
|
||||
|
||||
@@ -457,7 +482,8 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv1DWithBias)
|
||||
|
||||
TEST(ConvTensorLayoutsWithAuxiliary, Conv3DWithBias)
|
||||
{
|
||||
using OutputOp = TensorOperation<TensorConfig{.layout = TensorLayout::G_C_strided}>;
|
||||
using namespace enums;
|
||||
using OutputOp = TensorOperation<TensorConfig{.layout = G_C_strided}>;
|
||||
|
||||
static constexpr auto sig =
|
||||
ConvSignature<ConvolutionTensor<>, ConvolutionTensor<>, ConvolutionTensor<OutputOp>>{
|
||||
@@ -465,9 +491,9 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv3DWithBias)
|
||||
.direction = FORWARD,
|
||||
.data_type = DataType::FP16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::NDHWGC}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKZYXC}},
|
||||
.output = {.config = {.layout = TensorLayout::NDHWGK},
|
||||
.input = {.config = {.layout = NDHWGC}},
|
||||
.weight = {.config = {.layout = GKZYXC}},
|
||||
.output = {.config = {.layout = NDHWGK},
|
||||
.operation = OutputOp{.elementwise_operation =
|
||||
ElementwiseOperation::BIAS_BNORM_CLAMP}}};
|
||||
|
||||
|
||||
346
experimental/builder/test/utils/conv_algorithm_type_utils.hpp
Normal file
346
experimental/builder/test/utils/conv_algorithm_type_utils.hpp
Normal file
@@ -0,0 +1,346 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "../impl/conv_algorithm_types.hpp"
|
||||
#include <sstream>
|
||||
#include <array>
|
||||
|
||||
namespace ck_tile::builder::test {
|
||||
|
||||
namespace ckb = ck_tile::builder;
|
||||
|
||||
// Helper function to convert arrays to Seq(...) format
|
||||
template <typename T, size_t N>
|
||||
std::string array_to_seq(const std::array<T, N>& arr)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << "Seq(";
|
||||
for(size_t i = 0; i < N; ++i)
|
||||
{
|
||||
if(i > 0)
|
||||
oss << ",";
|
||||
oss << arr[i];
|
||||
}
|
||||
oss << ")";
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
// Base template - will cause compilation error for unsupported types
|
||||
template <typename T>
|
||||
std::string to_string(T)
|
||||
{
|
||||
static_assert(sizeof(T) == 0, "Unsupported type");
|
||||
return "";
|
||||
}
|
||||
|
||||
// Template specializations for enum types
|
||||
|
||||
template <>
|
||||
inline std::string to_string<PipelineVersion>(PipelineVersion t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << t;
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<PipelineScheduler>(PipelineScheduler t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << t;
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvFwdSpecialization>(ConvFwdSpecialization t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << t;
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<GemmSpecialization>(GemmSpecialization t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << t;
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
// Template specializations for struct types
|
||||
|
||||
template <>
|
||||
inline std::string to_string<MNK<size_t>>(MNK<size_t> t)
|
||||
{
|
||||
return array_to_seq(std::array<size_t, 3>{t.m, t.n, t.k});
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ThreadBlock>(ThreadBlock t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << t.block_size << "," << t.tile_size.m << "," << t.tile_size.n << "," << t.tile_size.k;
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<GridwiseXdlGemm>(GridwiseXdlGemm t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << t.ak1 << "," << t.bk1 << "," << t.m_per_xdl << "," << t.n_per_xdl << ","
|
||||
<< t.m_xdl_per_wave << "," << t.n_xdl_per_wave;
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<GridwiseWmmaGemm>(GridwiseWmmaGemm t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << t.k1 << "," << t.m_per_wmma << "," << t.n_per_wmma << "," << t.m_wmma_per_wave << ","
|
||||
<< t.n_wmma_per_wave;
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<BlockGemm>(BlockGemm t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(t.scheduler) << "," << to_string(t.pipeline_version);
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<BlockTransfer>(BlockTransfer t)
|
||||
{
|
||||
return array_to_seq(std::array<size_t, 3>{t.k0, t.m_n, t.k1});
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ThreadCluster>(ThreadCluster t)
|
||||
{
|
||||
return array_to_seq(
|
||||
std::array<size_t, 4>{t.m_block, t.m_wave_per_xdl, t.n_block, t.n_wave_per_xdl});
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<LdsTransfer>(LdsTransfer t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << t.src_vector_dim << "," << t.src_scalar_per_vector << "," << t.lds_dst_scalar_per_vector
|
||||
<< "," << (t.lds_padding ? "true" : "false") << ","
|
||||
<< (t.is_direct_load ? "true" : "false");
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<AccessOrder>(AccessOrder t)
|
||||
{
|
||||
return array_to_seq(t.order);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<TransferAB>(TransferAB t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(t.block_transfer) << "," << to_string(t.block_transfer_access_order) << ","
|
||||
<< to_string(t.src_access_order) << "," << t.lds_transfer.src_vector_dim << ","
|
||||
<< t.lds_transfer.src_scalar_per_vector << "," << t.lds_transfer.lds_dst_scalar_per_vector
|
||||
<< "," << (t.lds_transfer.lds_padding ? "true" : "false");
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<TransferC>(TransferC t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << t.epilogue.m_xdl_per_wave_per_shuffle << "," << t.epilogue.n_per_wave_per_shuffle << ","
|
||||
<< to_string(t.thread_cluster_dims) << "," << t.epilogue.scalar_per_vector;
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<TransferABC>(TransferABC t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(t.a) << "," << to_string(t.b) << "," << to_string(t.c);
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<DlThreadConfig>(DlThreadConfig t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << t.k1 << "," << t.m1_per_thread << "," << t.n1_per_thread << "," << t.k_per_thread;
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<DlThreadCluster>(DlThreadCluster t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << array_to_seq(t.m1_xs) << "," << array_to_seq(t.n1_xs);
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<DlBlockTransfer>(DlBlockTransfer t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << array_to_seq(t.thread_slice_lengths) << "," << array_to_seq(t.thread_cluster_lengths)
|
||||
<< "," << array_to_seq(t.thread_cluster_arrange_order) << ","
|
||||
<< array_to_seq(t.src_access_order) << "," << array_to_seq(t.src_vector_tensor_lengths)
|
||||
<< "," << array_to_seq(t.src_vector_tensor_contiguous_dim_order) << ","
|
||||
<< array_to_seq(t.dst_vector_tensor_lengths);
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<DlEpilogue>(DlEpilogue t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << array_to_seq(t.src_dst_access_order) << "," << t.src_dst_vector_dim << ","
|
||||
<< t.dst_scalar_per_vector;
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<DlBlockTransferAB>(DlBlockTransferAB t)
|
||||
{
|
||||
return to_string(t.block_transfer);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<DlBlockTransferC>(DlBlockTransferC t)
|
||||
{
|
||||
return to_string(t.epilogue);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<DlTransferABC>(DlTransferABC t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(t.a) << "," << to_string(t.b) << "," << to_string(t.c);
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
// Template specializations for factory wrapper types
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ThreadBlock_>(ThreadBlock_ t)
|
||||
{
|
||||
return to_string(t.thread_block);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<XdlGemm_>(XdlGemm_ t)
|
||||
{
|
||||
return to_string(t.gridwise_gemm);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<WmmaGemm_>(WmmaGemm_ t)
|
||||
{
|
||||
return to_string(t.gridwise_gemm);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<Transfer_>(Transfer_ t)
|
||||
{
|
||||
return to_string(t.transfer);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvSpecialization_>(ConvSpecialization_ t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(t.fwd_specialization) << "," << to_string(t.gemm_specialization);
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<Prefetch_>(Prefetch_ t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << t.num_gemm_k_prefetch_stages << "," << t.num_groups_to_merge << ","
|
||||
<< to_string(t.loop_scheduler);
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<BlockGemm_>(BlockGemm_ t)
|
||||
{
|
||||
return to_string(t.block_gemm);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<DlThreadConfig_>(DlThreadConfig_ t)
|
||||
{
|
||||
return to_string(t.thread_config);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<DlThreadCluster_>(DlThreadCluster_ t)
|
||||
{
|
||||
return to_string(t.thread_cluster);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<DlTransfer_>(DlTransfer_ t)
|
||||
{
|
||||
return to_string(t.transfer);
|
||||
}
|
||||
|
||||
// Template specializations for algorithm types
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>(
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<XdlGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>(
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<XdlGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>(
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << "," << to_string(static_cast<WmmaGemm_>(t))
|
||||
<< "," << to_string(static_cast<Transfer_>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK>(
|
||||
ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK t)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << to_string(static_cast<ThreadBlock_>(t)) << ","
|
||||
<< to_string(static_cast<DlThreadConfig_>(t)) << ","
|
||||
<< to_string(static_cast<DlThreadCluster_>(t)) << ","
|
||||
<< to_string(static_cast<DlTransfer_>(t));
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string to_string<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor>(
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor t)
|
||||
{
|
||||
return to_string(t.base_algorithm);
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
@@ -109,65 +109,145 @@ struct BlockwiseGemmWmmaops_pipeline_base
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t ScaleSliceSizeN,
|
||||
template <index_t ScaleSliceSizeMN,
|
||||
index_t ScaleSliceStrideMN,
|
||||
index_t ScaleSliceSizeK,
|
||||
index_t NWaves,
|
||||
index_t ScaleBlockK,
|
||||
index_t NumberOfBuffers,
|
||||
index_t RegSizePerWmma,
|
||||
typename GridDesc,
|
||||
typename ThreadCopy,
|
||||
typename GridBuffer,
|
||||
typename ThreadStaticBuffer,
|
||||
typename BScaleThreadDesc>
|
||||
struct BScale
|
||||
typename ThreadDesc>
|
||||
struct ABScale
|
||||
{
|
||||
__device__ BScale(GridDesc b_scale_grid_desc_,
|
||||
ThreadCopy b_scale_thread_copy_,
|
||||
GridBuffer b_scale_grid_buf_)
|
||||
: b_scale_thread_copy(b_scale_thread_copy_),
|
||||
b_scale_grid_desc(b_scale_grid_desc_),
|
||||
b_scale_grid_buf(b_scale_grid_buf_) {};
|
||||
__device__ ABScale(GridDesc scale_grid_desc_,
|
||||
ThreadCopy scale_thread_copy_,
|
||||
GridBuffer scale_grid_buf_)
|
||||
: scale_thread_copy(scale_thread_copy_),
|
||||
scale_grid_desc(scale_grid_desc_),
|
||||
scale_grid_buf(scale_grid_buf_) {};
|
||||
|
||||
static constexpr index_t num_scale_k_block = BScaleThreadDesc{}.GetLength(Number<1>{});
|
||||
static constexpr index_t num_scale_k_block = ThreadDesc{}.GetLength(Number<1>{});
|
||||
static constexpr index_t num_scale_krepeat = KRepeat / num_scale_k_block;
|
||||
|
||||
static constexpr auto b_scale_thread_desc = BScaleThreadDesc{};
|
||||
static constexpr index_t num_slice_mn = ScaleSliceSizeMN;
|
||||
static constexpr index_t num_slice_k = ScaleSliceSizeK;
|
||||
static constexpr index_t reg_size_per_wmma = RegSizePerWmma;
|
||||
|
||||
static constexpr auto b_scale_thread_copy_step =
|
||||
make_tuple(make_multi_index(NWaves * NPerWmma, 0),
|
||||
make_multi_index(-NPerBlock, 0),
|
||||
make_multi_index(-NPerBlock, (KPerBlock + ScaleBlockK - 1) / ScaleBlockK));
|
||||
static constexpr auto scale_thread_desc = ThreadDesc{};
|
||||
|
||||
static constexpr auto scale_thread_copy_step =
|
||||
make_tuple(make_multi_index(ScaleSliceStrideMN, 0),
|
||||
make_multi_index(-ScaleSliceSizeMN / RegSizePerWmma * ScaleSliceStrideMN, 0),
|
||||
make_multi_index(-ScaleSliceSizeMN / RegSizePerWmma * ScaleSliceStrideMN,
|
||||
ScaleSliceSizeK));
|
||||
|
||||
template <index_t NBuffer>
|
||||
__device__ void GlobalLoad(bool cond)
|
||||
{
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc,
|
||||
make_tuple(n0, Number<0>{}),
|
||||
b_scale_thread_bufs(Number<NBuffer>{}));
|
||||
static_for<0, ScaleSliceSizeMN / RegSizePerWmma, 1>{}([&](auto m0) {
|
||||
scale_thread_copy.Run(scale_grid_desc,
|
||||
scale_grid_buf,
|
||||
scale_thread_desc,
|
||||
make_tuple(m0 * Number<RegSizePerWmma>{}, Number<0>{}),
|
||||
scale_thread_bufs(Number<NBuffer>{}));
|
||||
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
|
||||
b_scale_thread_copy_step.At(Number<0>{}));
|
||||
scale_thread_copy.MoveSrcSliceWindow(scale_grid_desc,
|
||||
scale_thread_copy_step.At(Number<0>{}));
|
||||
});
|
||||
|
||||
if(cond)
|
||||
{
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
|
||||
b_scale_thread_copy_step.At(Number<2>{}));
|
||||
scale_thread_copy.MoveSrcSliceWindow(scale_grid_desc,
|
||||
scale_thread_copy_step.At(Number<2>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
|
||||
b_scale_thread_copy_step.At(Number<1>{}));
|
||||
scale_thread_copy.MoveSrcSliceWindow(scale_grid_desc,
|
||||
scale_thread_copy_step.At(Number<1>{}));
|
||||
}
|
||||
}
|
||||
|
||||
ThreadCopy b_scale_thread_copy;
|
||||
GridDesc b_scale_grid_desc;
|
||||
GridBuffer b_scale_grid_buf;
|
||||
StaticallyIndexedArray<ThreadStaticBuffer, Number<NumberOfBuffers>{}> b_scale_thread_bufs;
|
||||
ThreadCopy scale_thread_copy;
|
||||
GridDesc scale_grid_desc;
|
||||
GridBuffer scale_grid_buf;
|
||||
StaticallyIndexedArray<ThreadStaticBuffer, Number<NumberOfBuffers>{}> scale_thread_bufs;
|
||||
};
|
||||
|
||||
template <typename AScaleStruct, typename BScaleStruct>
|
||||
struct CScale
|
||||
{
|
||||
__device__ CScale() {}
|
||||
|
||||
static constexpr auto reg_size_per_wmma =
|
||||
ck::is_same_v<BScaleStruct, Empty> && ck::is_same_v<AScaleStruct, Empty>
|
||||
? 1
|
||||
: wmma_gemm.GetRegSizePerWmma();
|
||||
static constexpr auto c_scale_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
|
||||
Number<ck::math::max(AScaleStruct::num_slice_k, BScaleStruct::num_slice_k)>{},
|
||||
Number<AScaleStruct::num_slice_mn>{},
|
||||
Number<BScaleStruct::num_slice_mn>{}));
|
||||
using CScaleThreadDesc = decltype(c_scale_thread_desc);
|
||||
static constexpr auto num_scale_k_block = CScaleThreadDesc{}.GetLength(Number<0>{});
|
||||
static constexpr auto num_scale_m_block = CScaleThreadDesc{}.GetLength(Number<1>{});
|
||||
static constexpr auto num_scale_n_block = CScaleThreadDesc{}.GetLength(Number<2>{});
|
||||
using ThreadStaticBuffer = decltype(make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
|
||||
c_scale_thread_desc.GetElementSpaceSize()));
|
||||
|
||||
__device__ void Load(AScaleStruct& a_scale_struct, BScaleStruct& b_scale_struct)
|
||||
{
|
||||
using AScaleThreadDesc = decltype(AScaleStruct::scale_thread_desc);
|
||||
using BScaleThreadDesc = decltype(BScaleStruct::scale_thread_desc);
|
||||
|
||||
static_for<0, num_scale_m_block, 1>{}([&](auto m0) {
|
||||
static_for<0, num_scale_n_block, 1>{}([&](auto n0) {
|
||||
static_for<0, num_scale_k_block, 1>{}([&](auto k0) {
|
||||
constexpr index_t c_offset =
|
||||
CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
|
||||
constexpr index_t a_offset =
|
||||
AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
|
||||
constexpr index_t b_offset =
|
||||
BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
|
||||
|
||||
c_scale_thread_bufs(I0)(Number<c_offset>{}) =
|
||||
a_scale_struct.scale_thread_bufs(I0)[Number<a_offset>{}] *
|
||||
b_scale_struct.scale_thread_bufs(I0)[Number<b_offset>{}];
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
__device__ void Clear()
|
||||
{
|
||||
static_for<0, reg_size_per_wmma, 1>{}([&](auto t) {
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<AccDataType>()(Number<t>{}) = 0;
|
||||
});
|
||||
}
|
||||
|
||||
template <index_t k_index, index_t m_index, index_t n_index, typename CThreadBuf>
|
||||
__device__ void UpdateCThreadBuf(CThreadBuf& c_thread_buf)
|
||||
{
|
||||
static_for<0, reg_size_per_wmma, 1>{}([&](auto t) {
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m_index, n_index, t));
|
||||
constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(make_tuple(
|
||||
k_index,
|
||||
(m_index * num_scale_m_block / MRepeat) % num_scale_m_block +
|
||||
(Number<t / (reg_size_per_wmma / AScaleStruct::reg_size_per_wmma)>{}) %
|
||||
AScaleStruct::reg_size_per_wmma,
|
||||
(n_index * num_scale_n_block / NRepeat) % num_scale_n_block));
|
||||
c_thread_buf(Number<c_offset>{}) +=
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<AccDataType>()[Number<t>{}] *
|
||||
type_convert<AccDataType>(c_scale_thread_bufs(I0)[Number<cscale_offset>{}]);
|
||||
});
|
||||
}
|
||||
|
||||
StaticallyIndexedArray<ThreadStaticBuffer, Number<1>{}> c_scale_thread_bufs;
|
||||
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, AccDataType, 1, reg_size_per_wmma, true>
|
||||
c_thread_buf_per_scale;
|
||||
};
|
||||
|
||||
__host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
|
||||
|
||||
@@ -174,7 +174,9 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename CThreadBuffer,
|
||||
typename BScaleStruct>
|
||||
typename AScaleStruct,
|
||||
typename BScaleStruct,
|
||||
typename enable_if<ck::is_same_v<AScaleStruct, Empty>, bool>::type = false>
|
||||
__device__ void Run(const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
@@ -188,7 +190,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
// BScaleThreadCopy
|
||||
AScaleStruct&,
|
||||
BScaleStruct& b_scale_struct,
|
||||
index_t num_loop,
|
||||
index_t num_loop_per_scale) const
|
||||
@@ -207,6 +209,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Scales global load
|
||||
b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
|
||||
|
||||
// Local prefill 1
|
||||
@@ -217,6 +220,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
c_thread_buf.Clear();
|
||||
|
||||
auto blockwise_gemm_func = [&]() {
|
||||
// Local load
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
|
||||
@@ -245,7 +249,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
b_block_desc_k0_n0_n1_n2_k1,
|
||||
make_tuple(I0, n0, k0, I0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_scale_struct.b_scale_thread_bufs(
|
||||
b_scale_struct.scale_thread_bufs(
|
||||
I0)[Number<n0 * BScaleStruct::num_scale_k_block +
|
||||
k0 / BScaleStruct::num_scale_krepeat>{}],
|
||||
b_thread_desc_,
|
||||
@@ -366,6 +370,189 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
}
|
||||
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
TailNumber TailNum,
|
||||
typename AGridDesc,
|
||||
typename ABlockDesc,
|
||||
typename ABlockTransfer,
|
||||
typename AGridBuffer,
|
||||
typename ABlockBuffer,
|
||||
typename ABlockTransferStep,
|
||||
typename BGridDesc,
|
||||
typename BBlockDesc,
|
||||
typename BBlockTransfer,
|
||||
typename BGridBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename CThreadBuffer,
|
||||
typename AScaleStruct,
|
||||
typename BScaleStruct,
|
||||
typename enable_if<!ck::is_same_v<AScaleStruct, Empty> &&
|
||||
!ck::is_same_v<BScaleStruct, Empty>,
|
||||
bool>::type = false>
|
||||
__device__ void Run(const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
const BGridDesc& b_grid_desc,
|
||||
const BBlockDesc& b_block_desc,
|
||||
BBlockTransfer& b_blockwise_copy,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
AScaleStruct& a_scale_struct,
|
||||
BScaleStruct& b_scale_struct,
|
||||
index_t num_loop,
|
||||
index_t num_loop_per_scale) const
|
||||
{
|
||||
constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
|
||||
static constexpr auto NumScaleKBlock =
|
||||
Number<ck::math::max(AScaleStruct::num_slice_k, BScaleStruct::num_slice_k)>{};
|
||||
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
|
||||
Base::a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
|
||||
Base::b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
using CScaleStruct = typename Base::template CScale<AScaleStruct, BScaleStruct>;
|
||||
auto c_scale_struct = CScaleStruct{};
|
||||
|
||||
// Global prefetch 1
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Scales global load
|
||||
a_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
|
||||
b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
|
||||
|
||||
c_scale_struct.Load(a_scale_struct, b_scale_struct);
|
||||
|
||||
// Local prefill 1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
|
||||
auto blockwise_gemm_func = [&]() {
|
||||
// Local load
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
Base::a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
|
||||
make_tuple(I0, m0, k0, I0, I0, I0, I0),
|
||||
a_block_buf,
|
||||
Base::a_thread_desc_,
|
||||
make_tuple(I0, m0, k0, I0, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
Base::b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
|
||||
make_tuple(I0, n0, k0, I0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
Base::b_thread_desc_,
|
||||
make_tuple(I0, n0, k0, I0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) {
|
||||
c_scale_struct.Clear();
|
||||
static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
|
||||
|
||||
static_for<0, KInner, 1>{}([&](auto k_inner) {
|
||||
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
|
||||
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
|
||||
constexpr index_t k_index =
|
||||
kscale0 * (KRepeat / NumScaleKBlock) + k0;
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<Base::a_thread_desc_.CalculateOffset(
|
||||
make_tuple(Number<kk / A_K1>{},
|
||||
m0,
|
||||
k_index,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
Number<kk % A_K1>{}))>{}];
|
||||
});
|
||||
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
|
||||
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
|
||||
constexpr index_t k_index =
|
||||
kscale0 * (KRepeat / NumScaleKBlock) + k0;
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_buf[Number<Base::b_thread_desc_.CalculateOffset(
|
||||
make_tuple(Number<kk / B_K1>{},
|
||||
n0,
|
||||
k_index,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
Number<kk % B_K1>{}))>{}];
|
||||
});
|
||||
|
||||
using wmma_input_type_a =
|
||||
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
|
||||
using wmma_input_type_b =
|
||||
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
|
||||
|
||||
wmma_gemm.Run(
|
||||
a_thread_vec.template AsType<wmma_input_type_a>(),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>(),
|
||||
c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference(
|
||||
Number<0>{}));
|
||||
});
|
||||
});
|
||||
c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
// main body
|
||||
if constexpr(HasMainLoop)
|
||||
{
|
||||
index_t i = 0;
|
||||
do
|
||||
{
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
a_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
|
||||
b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
|
||||
|
||||
block_sync_lds();
|
||||
blockwise_gemm_func();
|
||||
|
||||
block_sync_lds();
|
||||
c_scale_struct.Load(a_scale_struct, b_scale_struct);
|
||||
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
i += 1;
|
||||
} while(i < (num_loop - 1));
|
||||
}
|
||||
|
||||
// tail
|
||||
if constexpr(TailNum == TailNumber::Full)
|
||||
{
|
||||
block_sync_lds();
|
||||
blockwise_gemm_func();
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
// A[MRepeat, I1, I1, KPack]
|
||||
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
@@ -528,6 +715,23 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
return TailNumber::Full;
|
||||
}
|
||||
|
||||
template <typename AScaleStruct, typename BScaleStruct>
|
||||
struct KLoopParams
|
||||
{
|
||||
static constexpr auto KRepeatNoScale = 1;
|
||||
static constexpr auto NumScaleKBlock =
|
||||
Number<ck::math::max(AScaleStruct::num_slice_k, BScaleStruct::num_slice_k)>{};
|
||||
static constexpr auto KRepeatPerNumScaleKBlock = KRepeatPerCluster / NumScaleKBlock;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct KLoopParams<Empty, Empty>
|
||||
{
|
||||
static constexpr index_t KRepeatNoScale = KRepeatPerCluster;
|
||||
static constexpr index_t NumScaleKBlock = 1;
|
||||
static constexpr index_t KRepeatPerNumScaleKBlock = 1;
|
||||
};
|
||||
|
||||
template <bool HasMainLoop,
|
||||
TailNumber TailNum,
|
||||
typename AGridDesc,
|
||||
@@ -543,7 +747,9 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename CThreadBuffer,
|
||||
typename BScaleStruct>
|
||||
typename AScaleStruct,
|
||||
typename BScaleStruct,
|
||||
typename enable_if<ck::is_same_v<AScaleStruct, Empty>, bool>::type = false>
|
||||
__device__ void Run(const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
@@ -557,7 +763,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
// BScaleThreadCopy
|
||||
AScaleStruct&,
|
||||
BScaleStruct& b_scale_struct,
|
||||
index_t num_loop,
|
||||
index_t num_loop_per_scale) const
|
||||
@@ -576,6 +782,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Scales global load
|
||||
b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
|
||||
|
||||
// Local prefill 1
|
||||
@@ -615,7 +822,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
b_block_desc_k0_n0_n1_n2_k1,
|
||||
make_tuple(I0, n0, k0_offset + k0_inner, I0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_scale_struct.b_scale_thread_bufs(I0)[Number<
|
||||
b_scale_struct.scale_thread_bufs(I0)[Number<
|
||||
n0 * BScaleStruct::num_scale_k_block +
|
||||
(k0_offset + k0_inner) / BScaleStruct::num_scale_krepeat>{}],
|
||||
b_thread_desc_,
|
||||
@@ -704,6 +911,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
__builtin_amdgcn_s_setprio(0);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
@@ -982,7 +1190,9 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename CThreadBuffer,
|
||||
typename BScaleStruct>
|
||||
typename AScaleStruct,
|
||||
typename BScaleStruct,
|
||||
typename enable_if<ck::is_same_v<AScaleStruct, Empty>, bool>::type = false>
|
||||
__device__ void Run(const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
@@ -996,7 +1206,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
BBlockBuffer&,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
// BScaleThreadCopy
|
||||
AScaleStruct&,
|
||||
BScaleStruct&,
|
||||
index_t num_loop,
|
||||
index_t) const
|
||||
@@ -1319,6 +1529,248 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
}
|
||||
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
TailNumber TailNum,
|
||||
typename AGridDesc,
|
||||
typename ABlockDesc,
|
||||
typename ABlockTransfer,
|
||||
typename AGridBuffer,
|
||||
typename ABlockBuffer,
|
||||
typename ABlockTransferStep,
|
||||
typename BGridDesc,
|
||||
typename BBlockDesc,
|
||||
typename BBlockTransfer,
|
||||
typename BGridBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename CThreadBuffer,
|
||||
typename AScaleStruct,
|
||||
typename BScaleStruct,
|
||||
typename enable_if<!ck::is_same_v<AScaleStruct, Empty> &&
|
||||
!ck::is_same_v<BScaleStruct, Empty>,
|
||||
bool>::type = false>
|
||||
__device__ void Run(const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
const BGridDesc& b_grid_desc,
|
||||
const BBlockDesc&,
|
||||
BBlockTransfer& b_blockwise_copy,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
BBlockBuffer&,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
AScaleStruct& a_scale_struct,
|
||||
BScaleStruct& b_scale_struct,
|
||||
index_t num_loop,
|
||||
index_t num_loop_per_scale) const
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
|
||||
static constexpr auto NumScaleKBlock =
|
||||
Number<ck::math::max(AScaleStruct::num_slice_k, BScaleStruct::num_slice_k)>{};
|
||||
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
|
||||
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
|
||||
|
||||
using CScaleStruct = typename Base::template CScale<AScaleStruct, BScaleStruct>;
|
||||
auto c_scale_struct = CScaleStruct{};
|
||||
|
||||
auto gemm_core_func = [&](auto reg_buf) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) {
|
||||
c_scale_struct.Clear();
|
||||
static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
|
||||
static_for<0, KInner, 1>{}([&](auto k_inner) {
|
||||
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
|
||||
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
|
||||
constexpr index_t k_index =
|
||||
kscale0 * (KRepeat / NumScaleKBlock) + k0;
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(Number<kk / A_K1>{},
|
||||
m0,
|
||||
k_index,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
Number<kk % A_K1>{}))>{}];
|
||||
});
|
||||
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
|
||||
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
|
||||
constexpr index_t k_index =
|
||||
kscale0 * (KRepeat / NumScaleKBlock) + k0;
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_bufs[reg_buf]
|
||||
[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(Number<kk / B_K1>{},
|
||||
I0,
|
||||
I0,
|
||||
n0,
|
||||
I0,
|
||||
k_index,
|
||||
Number<kk % B_K1>{}))>{}];
|
||||
});
|
||||
using wmma_input_type_a =
|
||||
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
|
||||
using wmma_input_type_b =
|
||||
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
|
||||
wmma_gemm.Run(
|
||||
a_thread_vec.template AsType<wmma_input_type_a>(),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>(),
|
||||
c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference(
|
||||
Number<0>{}));
|
||||
});
|
||||
});
|
||||
c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
auto a_local_prefetch_func = [&]() {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
|
||||
make_tuple(I0, m0, k0, I0, I0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(I0, m0, k0, I0, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
// Global prefetch A1 B1
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.Run(b_grid_desc,
|
||||
b_grid_buf,
|
||||
b_block_desc_k0_n0_n1_n2_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(I0));
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Scales global load
|
||||
a_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
|
||||
b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
c_scale_struct.Load(a_scale_struct, b_scale_struct);
|
||||
|
||||
// Local prefill A1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
|
||||
// Global prefetch A2
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
|
||||
// Local prefetch A1
|
||||
block_sync_lds();
|
||||
a_local_prefetch_func();
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// main body
|
||||
if constexpr(HasMainLoop)
|
||||
{
|
||||
index_t i = 0;
|
||||
do
|
||||
{
|
||||
auto LoopFunc = [&](auto wmma_reg_buf, auto local_read_buf) {
|
||||
b_blockwise_copy.Run(b_grid_desc,
|
||||
b_grid_buf,
|
||||
b_block_desc_k0_n0_n1_n2_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(local_read_buf));
|
||||
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, wmma_reg_buf);
|
||||
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
|
||||
a_scale_struct.template GlobalLoad<0>(
|
||||
(i + 2 + wmma_reg_buf) % num_loop_per_scale == 0);
|
||||
b_scale_struct.template GlobalLoad<0>(
|
||||
(i + 2 + wmma_reg_buf) % num_loop_per_scale == 0);
|
||||
|
||||
gemm_core_func(wmma_reg_buf);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// loop prefetch copy
|
||||
a_local_prefetch_func();
|
||||
|
||||
c_scale_struct.Load(a_scale_struct, b_scale_struct);
|
||||
|
||||
// HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
};
|
||||
|
||||
LoopFunc(I0, I1);
|
||||
LoopFunc(I1, I0);
|
||||
|
||||
i += 2;
|
||||
} while(i < (num_loop - 2));
|
||||
}
|
||||
|
||||
// tail
|
||||
if constexpr(TailNum == TailNumber::Even)
|
||||
{
|
||||
b_blockwise_copy.Run(b_grid_desc,
|
||||
b_grid_buf,
|
||||
b_block_desc_k0_n0_n1_n2_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(I1));
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
|
||||
a_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);
|
||||
b_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);
|
||||
|
||||
gemm_core_func(I0);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// tail Local Prefetch A1
|
||||
a_local_prefetch_func();
|
||||
|
||||
c_scale_struct.Load(a_scale_struct, b_scale_struct);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
gemm_core_func(I1);
|
||||
// Let's leak last WMMA block to epilogue region, cover the potential lds-shuffle
|
||||
// latency
|
||||
// __builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Odd)
|
||||
{
|
||||
gemm_core_func(I0);
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
static constexpr auto b_thread_desc_ =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<KPack / B_K1 / B_KRow>{},
|
||||
|
||||
@@ -123,6 +123,9 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
KInner,
|
||||
TransposeC>;
|
||||
using Base::I0;
|
||||
using Base::I1;
|
||||
using Base::I2;
|
||||
using Base::I3;
|
||||
|
||||
using Base::A_K1;
|
||||
using Base::A_KRow;
|
||||
@@ -322,7 +325,7 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
|
||||
make_tuple(I0, n0, k0, I0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_scale_struct.b_scale_thread_bufs(
|
||||
b_scale_struct.scale_thread_bufs(
|
||||
I0)[Number<n0 * BScaleStruct::num_scale_k_block +
|
||||
k0 / BScaleStruct::num_scale_krepeat>{}],
|
||||
b_thread_desc_,
|
||||
@@ -348,7 +351,9 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename CThreadBuffer,
|
||||
typename BScaleStruct>
|
||||
typename AScaleStruct,
|
||||
typename BScaleStruct,
|
||||
typename enable_if<ck::is_same_v<AScaleStruct, Empty>, bool>::type = false>
|
||||
__device__ void Run(const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
@@ -362,7 +367,7 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
// BScaleThreadCopy
|
||||
AScaleStruct&,
|
||||
BScaleStruct& b_scale_struct,
|
||||
index_t num_loop,
|
||||
index_t num_loop_per_scale) const
|
||||
@@ -383,6 +388,7 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Scales global load
|
||||
b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
|
||||
|
||||
// Local prefill 1
|
||||
@@ -611,6 +617,339 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
}
|
||||
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
TailNumber TailNum,
|
||||
typename AGridDesc,
|
||||
typename ABlockDesc,
|
||||
typename ABlockTransfer,
|
||||
typename AGridBuffer,
|
||||
typename ABlockBuffer,
|
||||
typename ABlockTransferStep,
|
||||
typename BGridDesc,
|
||||
typename BBlockDesc,
|
||||
typename BBlockTransfer,
|
||||
typename BGridBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename CThreadBuffer,
|
||||
typename AScaleStruct,
|
||||
typename BScaleStruct,
|
||||
typename enable_if<!ck::is_same_v<AScaleStruct, Empty> &&
|
||||
!ck::is_same_v<BScaleStruct, Empty>,
|
||||
bool>::type = false>
|
||||
__device__ void Run(const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
const BGridDesc& b_grid_desc,
|
||||
const BBlockDesc& b_block_desc,
|
||||
BBlockTransfer& b_blockwise_copy,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
AScaleStruct& a_scale_struct,
|
||||
BScaleStruct& b_scale_struct,
|
||||
index_t num_loop,
|
||||
index_t num_loop_per_scale) const
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
|
||||
static constexpr auto NumScaleKBlock =
|
||||
Number<ck::math::max(AScaleStruct::num_slice_k, BScaleStruct::num_slice_k)>{};
|
||||
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
using CScaleStruct = typename Base::template CScale<AScaleStruct, BScaleStruct>;
|
||||
auto c_scale_struct = CScaleStruct{};
|
||||
|
||||
// Global prefetch 1
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Scales global load
|
||||
a_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
|
||||
b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
|
||||
|
||||
c_scale_struct.Load(a_scale_struct, b_scale_struct);
|
||||
|
||||
// Local prefill 1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
// Global prefetch 2, perform when at least 2 loops exist.
|
||||
if constexpr(TailNum == TailNumber::Even || TailNum == TailNumber::Full)
|
||||
{
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
}
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
|
||||
// Local prefetch 1
|
||||
block_sync_lds();
|
||||
|
||||
auto local_load_func = [&]() {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
|
||||
make_tuple(I0, m0, k0, I0, I0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(I0, m0, k0, I0, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
|
||||
make_tuple(I0, n0, k0, I0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, n0, k0, I0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
local_load_func();
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// Main body, perform when at least 3 loops exist.
|
||||
if constexpr(HasMainLoop)
|
||||
{
|
||||
index_t i = 0;
|
||||
do
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
a_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
|
||||
b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) {
|
||||
c_scale_struct.Clear();
|
||||
static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) {
|
||||
static_for<0, KInner, 1>{}([&](auto k_inner) {
|
||||
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
|
||||
|
||||
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
|
||||
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
|
||||
constexpr index_t k_index =
|
||||
kscale0 * (KRepeat / NumScaleKBlock) + k0;
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(Number<kk / A_K1>{},
|
||||
m0,
|
||||
k_index,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
Number<kk % A_K1>{}))>{}];
|
||||
});
|
||||
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
|
||||
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
|
||||
constexpr index_t k_index =
|
||||
kscale0 * (KRepeat / NumScaleKBlock) + k0;
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(Number<kk / B_K1>{},
|
||||
n0,
|
||||
k_index,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
Number<kk % B_K1>{}))>{}];
|
||||
});
|
||||
|
||||
using wmma_input_type_a =
|
||||
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
|
||||
using wmma_input_type_b =
|
||||
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
|
||||
|
||||
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>(),
|
||||
c_scale_struct.c_thread_buf_per_scale
|
||||
.GetVectorTypeReference(Number<0>{}));
|
||||
});
|
||||
});
|
||||
c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
c_scale_struct.Load(a_scale_struct, b_scale_struct);
|
||||
block_sync_lds();
|
||||
|
||||
local_load_func();
|
||||
|
||||
HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
i += 1;
|
||||
} while(i < (num_loop - 2));
|
||||
}
|
||||
|
||||
// Pre-tail, perform when at least 2 loops exist.
|
||||
if constexpr(TailNum == TailNumber::Even || TailNum == TailNumber::Full)
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
// No RunRead or MoveSrcSliceWindow here, already finished them all!
|
||||
a_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);
|
||||
b_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) {
|
||||
c_scale_struct.Clear();
|
||||
static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) {
|
||||
static_for<0, KInner, 1>{}([&](auto k_inner) {
|
||||
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
|
||||
|
||||
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
|
||||
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
|
||||
constexpr index_t k_index =
|
||||
kscale0 * (KRepeat / NumScaleKBlock) + k0;
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(Number<kk / A_K1>{},
|
||||
m0,
|
||||
k_index,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
Number<kk % A_K1>{}))>{}];
|
||||
});
|
||||
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
|
||||
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
|
||||
constexpr index_t k_index =
|
||||
kscale0 * (KRepeat / NumScaleKBlock) + k0;
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(Number<kk / B_K1>{},
|
||||
n0,
|
||||
k_index,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
Number<kk % B_K1>{}))>{}];
|
||||
});
|
||||
|
||||
using wmma_input_type_a =
|
||||
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
|
||||
using wmma_input_type_b =
|
||||
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
|
||||
|
||||
wmma_gemm.Run(
|
||||
a_thread_vec.template AsType<wmma_input_type_a>(),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>(),
|
||||
c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference(
|
||||
Number<0>{}));
|
||||
});
|
||||
});
|
||||
c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
c_scale_struct.Load(a_scale_struct, b_scale_struct);
|
||||
block_sync_lds();
|
||||
|
||||
local_load_func();
|
||||
|
||||
HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
|
||||
// Tail, always perform.
|
||||
{
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) {
|
||||
c_scale_struct.Clear();
|
||||
static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
|
||||
static_for<0, KInner, 1>{}([&](auto k_inner) {
|
||||
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
|
||||
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
|
||||
constexpr index_t k_index =
|
||||
kscale0 * (KRepeat / NumScaleKBlock) + k0;
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(Number<kk / A_K1>{},
|
||||
m0,
|
||||
k_index,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
Number<kk % A_K1>{}))>{}];
|
||||
});
|
||||
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
|
||||
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
|
||||
constexpr index_t k_index =
|
||||
kscale0 * (KRepeat / NumScaleKBlock) + k0;
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(Number<kk / B_K1>{},
|
||||
n0,
|
||||
k_index,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
Number<kk % B_K1>{}))>{}];
|
||||
});
|
||||
|
||||
using wmma_input_type_a =
|
||||
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
|
||||
using wmma_input_type_b =
|
||||
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
|
||||
|
||||
wmma_gemm.Run(
|
||||
a_thread_vec.template AsType<wmma_input_type_a>(),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>(),
|
||||
c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference(
|
||||
Number<0>{}));
|
||||
});
|
||||
});
|
||||
c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
// Let's leak last WMMA block to epilogue region, cover the potential lds-shuffle
|
||||
// latency
|
||||
// __builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
using Base::a_thread_copy_;
|
||||
using Base::a_thread_desc_;
|
||||
|
||||
@@ -105,6 +105,353 @@ struct DeviceGemmMultipleD_BlockScale_BPreshuffle : public BaseOperator
|
||||
virtual int GetPreShuffleParameters() = 0;
|
||||
};
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename AScaleType,
|
||||
typename BDataType,
|
||||
typename BScaleType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
index_t ScaleBlockM,
|
||||
index_t ScaleBlockN,
|
||||
index_t ScaleBlockK,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation>
|
||||
struct DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_e,
|
||||
const ck::index_t M,
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t StrideA,
|
||||
const ck::index_t StrideB,
|
||||
const std::array<ck::index_t, NumDTensor> StrideDs,
|
||||
const ck::index_t StrideE,
|
||||
const void* p_a_scale,
|
||||
const void* p_b_scale,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op,
|
||||
index_t KBatch) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
|
||||
virtual int GetPreShuffleParameters() = 0;
|
||||
};
|
||||
|
||||
/// @brief Wrapper for backward compatibility that allows to use instances of
|
||||
/// DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK in contexts where
|
||||
// DeviceGemmMultipleD_BlockScale_BPreshuffle is expected.
|
||||
///
|
||||
/// @note The main area where it can be used is DeviceOperationInstanceFactory::GetInstances().
|
||||
/// The only difference between API of DeviceGemmMultipleD_BlockScale_BPreshuffle and
|
||||
// DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK is
|
||||
/// that DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK::MakeArgumentPointer requires
|
||||
// an additional parameter KBatch which is explicitly passed as 1 by this wrapper.
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename AScaleType,
|
||||
typename BDataType,
|
||||
typename BScaleType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
index_t ScaleBlockM,
|
||||
index_t ScaleBlockN,
|
||||
index_t ScaleBlockK,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation>
|
||||
struct DeviceGemmMultipleD_BlockScale_BPreshuffleWrapper
|
||||
: public DeviceGemmMultipleD_BlockScale_BPreshuffle<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
AScaleType,
|
||||
BDataType,
|
||||
BScaleType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ScaleBlockM,
|
||||
ScaleBlockN,
|
||||
ScaleBlockK,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>
|
||||
{
|
||||
using DeviceOp = DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
AScaleType,
|
||||
BDataType,
|
||||
BScaleType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ScaleBlockM,
|
||||
ScaleBlockN,
|
||||
ScaleBlockK,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>;
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
#ifndef __HIPCC_RTC__
|
||||
|
||||
explicit DeviceGemmMultipleD_BlockScale_BPreshuffleWrapper(std::unique_ptr<DeviceOp> p_op)
|
||||
: p_op_(std::move(p_op))
|
||||
{
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return p_op_->IsSupportedArgument(p_arg);
|
||||
}
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_e,
|
||||
const ck::index_t M,
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t StrideA,
|
||||
const ck::index_t StrideB,
|
||||
const std::array<ck::index_t, NumDTensor> StrideDs,
|
||||
const ck::index_t StrideE,
|
||||
const void* p_a_scale,
|
||||
const void* p_b_scale,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op) override
|
||||
{
|
||||
return p_op_->MakeArgumentPointer(p_a,
|
||||
p_b,
|
||||
p_ds,
|
||||
p_e,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideE,
|
||||
p_a_scale,
|
||||
p_b_scale,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
1); // KBatch
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return p_op_->MakeInvokerPointer();
|
||||
}
|
||||
|
||||
int GetPreShuffleParameters() override { return p_op_->GetPreShuffleParameters(); }
|
||||
|
||||
std::string GetTypeString() const override { return p_op_->GetTypeString(); }
|
||||
|
||||
private:
|
||||
std::unique_ptr<DeviceOp> p_op_;
|
||||
|
||||
#endif // __HIPCC_RTC__
|
||||
};
|
||||
|
||||
// GEMM:
|
||||
// input : A[M, K], B[K, N],
|
||||
// input : D0[M, N], D1[M, N], ...
|
||||
// output : E[M, N]
|
||||
// C = a_op(A) * b_op(B)
|
||||
// E = cde_op(C, D0, D1, ...)
|
||||
// Assume:
|
||||
// D0, D1, ... and E have the same layout
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename AScaleType,
|
||||
typename BDataType,
|
||||
typename BScaleType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
index_t ScaleBlockM,
|
||||
index_t ScaleBlockN,
|
||||
index_t ScaleBlockK,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation>
|
||||
struct DeviceGemmMultipleD_ABScaleSplitK : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_e,
|
||||
const ck::index_t M,
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t StrideA,
|
||||
const ck::index_t StrideB,
|
||||
const std::array<ck::index_t, NumDTensor> StrideDs,
|
||||
const ck::index_t StrideE,
|
||||
const void* p_a_scale,
|
||||
const void* p_b_scale,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op,
|
||||
index_t KBatch) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
|
||||
virtual void SetKBatch(BaseArgument* arg, int KBatch) const = 0;
|
||||
};
|
||||
|
||||
/// @brief Wrapper for backward compatibility that allows to use instances of
|
||||
/// DeviceGemmMultipleD_ABScaleSplitK in contexts where DeviceGemmMultipleD_ABScale is
|
||||
/// expected.
|
||||
///
|
||||
/// @note The main area where it can be used is DeviceOperationInstanceFactory::GetInstances().
|
||||
/// The only difference between API of DeviceGemmMultipleD_ABScale and
|
||||
/// DeviceGemmMultipleD_ABScaleSplitK is that
|
||||
/// DeviceGemmMultipleD_ABScaleSplitK::MakeArgumentPointer requires a additional parameter
|
||||
/// KBatch which is explicitly passed as 1 by this wrapper.
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename AScaleType,
|
||||
typename BDataType,
|
||||
typename BScaleType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
index_t ScaleBlockM,
|
||||
index_t ScaleBlockN,
|
||||
index_t ScaleBlockK,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation>
|
||||
struct DeviceGemmMultipleD_ABScaleSplitKWrapper
|
||||
: public DeviceGemmMultipleD_ABScale<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
AScaleType,
|
||||
BDataType,
|
||||
BScaleType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ScaleBlockM,
|
||||
ScaleBlockN,
|
||||
ScaleBlockK,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>
|
||||
{
|
||||
|
||||
using DeviceOp = DeviceGemmMultipleD_ABScaleSplitK<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
AScaleType,
|
||||
BDataType,
|
||||
BScaleType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ScaleBlockM,
|
||||
ScaleBlockN,
|
||||
ScaleBlockK,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>;
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
#ifndef __HIPCC_RTC__
|
||||
|
||||
explicit DeviceGemmMultipleD_ABScaleSplitKWrapper(std::unique_ptr<DeviceOp> p_op)
|
||||
: p_op_(std::move(p_op))
|
||||
{
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return p_op_->IsSupportedArgument(p_arg);
|
||||
}
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_e,
|
||||
const ck::index_t M,
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t StrideA,
|
||||
const ck::index_t StrideB,
|
||||
const std::array<ck::index_t, NumDTensor> StrideDs,
|
||||
const ck::index_t StrideE,
|
||||
const void* p_a_scale,
|
||||
const void* p_b_scale,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op) override
|
||||
{
|
||||
return p_op_->MakeArgumentPointer(p_a,
|
||||
p_b,
|
||||
p_ds,
|
||||
p_e,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideE,
|
||||
p_a_scale,
|
||||
p_b_scale,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
1); // KBatch
|
||||
}
|
||||
|
||||
void SetKBatch(BaseArgument* arg, int KBatch) const override { p_op_->SetKBatch(arg, KBatch); }
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return p_op_->MakeInvokerPointer();
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override { return p_op_->GetTypeString(); }
|
||||
|
||||
private:
|
||||
std::unique_ptr<DeviceOp> p_op_;
|
||||
|
||||
#endif // __HIPCC_RTC__
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/flush_cache.hpp"
|
||||
@@ -93,7 +93,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
p_bs_grid_shift,
|
||||
karg.p_ds_grid,
|
||||
karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset,
|
||||
karg.p_b_scale_grid + b_scale_batch_offset + splitk_batch_offset.scale_k_split_offset,
|
||||
karg.p_a_scale_grid,
|
||||
karg.p_b_scale_grid + b_scale_batch_offset + splitk_batch_offset.scale_b_k_split_offset,
|
||||
p_shared,
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
@@ -315,12 +316,13 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3_BScale
|
||||
};
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_b_scale<
|
||||
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_ab_scale<
|
||||
ALayout,
|
||||
BLayout,
|
||||
Tuple<>, // DsLayout
|
||||
CLayout,
|
||||
Tuple<ADataType>,
|
||||
void, // AScaleType
|
||||
Tuple<BDataType>,
|
||||
BScaleDataType,
|
||||
AccDataType,
|
||||
@@ -332,6 +334,7 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3_BScale
|
||||
CElementwiseOperation,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
0, // ScaleBlockM
|
||||
ScaleBlockN,
|
||||
ScaleBlockK,
|
||||
MPerBlock,
|
||||
@@ -405,7 +408,9 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3_BScale
|
||||
std::array<index_t, 1>{StrideB_},
|
||||
std::array<index_t, 0>{}, // StrideDs_
|
||||
StrideC_,
|
||||
0, // StrideScaleA
|
||||
StrideScaleB_,
|
||||
nullptr,
|
||||
p_b_scale_grid_,
|
||||
k_batch_,
|
||||
a_element_op_,
|
||||
|
||||
@@ -0,0 +1,362 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/flush_cache.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename AScaleDataType,
|
||||
typename BDataType,
|
||||
typename BScaleDataType,
|
||||
typename DsDataType,
|
||||
typename CDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t BlockSize,
|
||||
index_t ScaleBlockM, // scale block for M
|
||||
index_t ScaleBlockN, // scale block for N
|
||||
index_t ScaleBlockK, // scale block for K
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t AK1,
|
||||
index_t BK1,
|
||||
index_t MPerWmma,
|
||||
index_t NPerWmma,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
bool BBlockLdsExtraN,
|
||||
index_t CShuffleMRepeatPerShuffle,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename CShuffleBlockTransferScalarPerVectors,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
|
||||
typename ComputeTypeA = CDataType,
|
||||
typename ComputeTypeB = ComputeTypeA,
|
||||
bool PermuteA = false,
|
||||
bool PermuteB = false>
|
||||
struct DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3
|
||||
: public DeviceGemmMultipleD_ABScaleSplitK<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
AScaleDataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
DsDataType,
|
||||
CDataType,
|
||||
ScaleBlockM,
|
||||
ScaleBlockN,
|
||||
ScaleBlockK,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_ab_scale<
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
Tuple<ADataType>,
|
||||
AScaleDataType,
|
||||
Tuple<BDataType>,
|
||||
BScaleDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
ScaleBlockM,
|
||||
ScaleBlockN,
|
||||
ScaleBlockK,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerWmma,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
false,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVectors,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
PermuteA,
|
||||
PermuteB>;
|
||||
|
||||
using Argument = typename GridwiseGemm::Argument;
|
||||
|
||||
using DeviceGemmCommon =
|
||||
DeviceGemm_Wmma_CShuffleV3_Common<GridwiseGemm,
|
||||
Tuple<ADataType>,
|
||||
Tuple<BDataType>,
|
||||
DsDataType,
|
||||
CDataType,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
BlockSize,
|
||||
AK1,
|
||||
BK1,
|
||||
GemmSpec,
|
||||
CShuffleBlockTransferScalarPerVectors,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>;
|
||||
|
||||
// Invoker
|
||||
using Invoker = typename DeviceGemmCommon::Invoker;
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
// with splitk the implementation doesn't work
|
||||
// when KRead % ScaleBlockK != 0, independently of K padding
|
||||
if(arg.KBatch > 1 && arg.KRead % ScaleBlockK != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return DeviceGemmCommon::IsSupportedArgument(arg);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
void SetKBatch(BaseArgument* base_arg, int KBatch) const override
|
||||
{
|
||||
auto& arg = *dynamic_cast<Argument*>(base_arg);
|
||||
arg.KBatch = KBatch;
|
||||
arg.KRead = GridwiseGemm::CalculateKRead(arg.K, KBatch);
|
||||
arg.KPadded = GridwiseGemm::CalculateKPadded(arg.K, KBatch);
|
||||
arg.AK0 = GridwiseGemm::CalculateAK0Padded(arg.K, KBatch);
|
||||
arg.BK0 = GridwiseGemm::CalculateBK0Padded(arg.K, KBatch);
|
||||
}
|
||||
|
||||
static auto MakeArgument(const ADataType* p_a,
|
||||
const BDataType* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
CDataType* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
const std::array<index_t, NumDTensor> StrideDs,
|
||||
index_t StrideC,
|
||||
const BScaleDataType* p_a_scale,
|
||||
const BScaleDataType* p_b_scale,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation cde_element_op,
|
||||
index_t KBatch = 1)
|
||||
{
|
||||
index_t StrideScaleA = ck::is_same_v<ALayout, tensor_layout::gemm::RowMajor>
|
||||
? math::integer_divide_ceil(K, ScaleBlockK)
|
||||
: math::integer_divide_ceil(M, ScaleBlockM);
|
||||
|
||||
index_t StrideScaleB = ck::is_same_v<BLayout, ck::tensor_layout::gemm::ColumnMajor>
|
||||
? math::integer_divide_ceil(K, ScaleBlockK)
|
||||
: math::integer_divide_ceil(N, ScaleBlockN);
|
||||
|
||||
return Argument{std::array<const void*, 1>{p_a},
|
||||
std::array<const void*, 1>{p_b},
|
||||
p_ds,
|
||||
p_c,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
std::array<index_t, 1>{StrideA},
|
||||
std::array<index_t, 1>{StrideB},
|
||||
StrideDs,
|
||||
StrideC,
|
||||
StrideScaleA,
|
||||
StrideScaleB,
|
||||
p_a_scale,
|
||||
p_b_scale,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
const std::array<ck::index_t, NumDTensor> StrideDs,
|
||||
index_t StrideC,
|
||||
const void* p_a_scale,
|
||||
const void* p_b_scale,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
index_t KBatch = 1) override
|
||||
{
|
||||
index_t StrideScaleA = ck::is_same_v<ALayout, tensor_layout::gemm::RowMajor>
|
||||
? math::integer_divide_ceil(K, ScaleBlockK)
|
||||
: math::integer_divide_ceil(M, ScaleBlockM);
|
||||
|
||||
index_t StrideScaleB = ck::is_same_v<BLayout, ck::tensor_layout::gemm::ColumnMajor>
|
||||
? math::integer_divide_ceil(K, ScaleBlockK)
|
||||
: math::integer_divide_ceil(N, ScaleBlockN);
|
||||
|
||||
return std::make_unique<Argument>(std::array<const void*, 1>{p_a},
|
||||
std::array<const void*, 1>{p_b},
|
||||
p_ds,
|
||||
static_cast<CDataType*>(p_c),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
std::array<index_t, 1>{StrideA},
|
||||
std::array<index_t, 1>{StrideB},
|
||||
StrideDs,
|
||||
StrideC,
|
||||
StrideScaleA,
|
||||
StrideScaleB,
|
||||
static_cast<const AScaleDataType*>(p_a_scale),
|
||||
static_cast<const BScaleDataType*>(p_b_scale),
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
|
||||
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
|
||||
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
|
||||
|
||||
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
|
||||
{BlockGemmPipelineVersion::v1, "v1"},
|
||||
{BlockGemmPipelineVersion::v2, "v2"},
|
||||
{BlockGemmPipelineVersion::v3, "v3"},
|
||||
{BlockGemmPipelineVersion::v4, "v4"},
|
||||
{BlockGemmPipelineVersion::v5, "v5"}};
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGemm_ABScale_Wmma_CShuffleV3"
|
||||
<< "<"
|
||||
<< getGemmSpecializationString(GemmSpec) << ", "
|
||||
<< std::string(ALayout::name)[0]
|
||||
<< std::string(BLayout::name)[0]
|
||||
<< std::string(CLayout::name)[0]
|
||||
<< ">"
|
||||
<< " BlkSize: "
|
||||
<< BlockSize << ", "
|
||||
<< "BlkTile: "
|
||||
<< MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
|
||||
<< "WaveTile: "
|
||||
<< MPerWmma<<"x"<<NPerWmma << ", "
|
||||
<< "WaveMap: "
|
||||
<< MRepeat<<"x" << NRepeat<<", "
|
||||
<< "VmemReadVec: "
|
||||
<< ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
|
||||
<< "BlkGemmPipelineScheduler: "
|
||||
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
|
||||
<< "BlkGemmPipelineVersion: "
|
||||
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
|
||||
<< "BlkGemmPipelinePrefetchStages: "
|
||||
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", "
|
||||
<< "KPack: "
|
||||
<< GridwiseGemm::KPack;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -18,52 +18,6 @@
|
||||
#include "ck/host_utility/flush_cache.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
index_t MinimumOccupancy = 1,
|
||||
TailNumber TailNum = TailNumber::Full>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
#endif
|
||||
kernel_gemm_b_preshuffle_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg)
|
||||
{
|
||||
#if(defined(__gfx11__) || defined(__gfx12__))
|
||||
#if defined(__gfx11__)
|
||||
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
|
||||
using e_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
|
||||
if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
|
||||
(std::is_same_v<e_data_type, ck::half_t> ||
|
||||
std::is_same_v<e_data_type, ck::bhalf_t>)))
|
||||
{
|
||||
#endif
|
||||
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
|
||||
typename GridwiseGemm::EpilogueCShuffle>();
|
||||
__shared__ char p_shared[LDS_size];
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
|
||||
const index_t num_k_per_block = math::integer_divide_ceil(karg.K, GridwiseGemm::KPack);
|
||||
const index_t k_id = blockIdx.z * num_k_per_block;
|
||||
|
||||
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
|
||||
p_shared, splitk_batch_offset, karg, epilogue_args, k_id);
|
||||
|
||||
#if defined(__gfx11__)
|
||||
}
|
||||
#endif
|
||||
#else
|
||||
ignore = karg;
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
@@ -202,270 +156,14 @@ struct DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>;
|
||||
ComputeTypeB,
|
||||
true>; // IsBPreshuffle
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
/// @brief This function issues GPU kernel execution.
|
||||
/// @param arg The GPU kernel arguments.
|
||||
/// @param stream_config The HIP stream configuration helper structure.
|
||||
/// @return The kernel's average execution time (if time measurement is
|
||||
/// enabled).
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
arg.Print();
|
||||
GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print();
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg))
|
||||
{
|
||||
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
|
||||
}
|
||||
|
||||
index_t gdx, gdy, gdz;
|
||||
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
index_t k_grain = arg.KBatch * KPerBlock;
|
||||
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
|
||||
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
|
||||
const auto Run = [&](const auto& kernel) {
|
||||
if(stream_config.flush_cache)
|
||||
{
|
||||
Argument arg_ = arg;
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1(
|
||||
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideAs, arg_.AK0);
|
||||
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1(
|
||||
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideBs, arg_.BK0);
|
||||
|
||||
std::array<std::size_t, 1> size_as_buffers;
|
||||
size_as_buffers[Number<0>{}] =
|
||||
a_grid_desc_ak0_m_ak1[Number<0>{}].GetElementSpaceSize() *
|
||||
sizeof(ADataType) / GridwiseGemm::APackedSize;
|
||||
|
||||
std::array<std::size_t, 1> size_bs_buffers;
|
||||
size_bs_buffers[Number<0>{}] =
|
||||
b_grid_desc_bk0_n_bk1[Number<0>{}].GetElementSpaceSize() *
|
||||
sizeof(BDataType) / GridwiseGemm::BPackedSize;
|
||||
|
||||
const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
|
||||
arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
|
||||
|
||||
std::array<std::size_t, GridwiseGemm::NumDTensor> size_ds_buffers;
|
||||
static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
size_ds_buffers[i] =
|
||||
ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
|
||||
});
|
||||
|
||||
ck::utility::RotatingMemWrapperMultiABD<Argument,
|
||||
Tuple<ADataType>,
|
||||
Tuple<BDataType>,
|
||||
DsDataType>
|
||||
rotating_mem(arg_,
|
||||
stream_config.rotating_count,
|
||||
size_as_buffers,
|
||||
size_bs_buffers,
|
||||
size_ds_buffers);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck::utility::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(arg_.KBatch > 1)
|
||||
HIP_CHECK_ERROR(hipMemsetAsync(arg_.p_e_grid,
|
||||
0,
|
||||
arg_.M * arg_.N * sizeof(EDataType),
|
||||
stream_config.stream_id_));
|
||||
};
|
||||
|
||||
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
|
||||
stream_config,
|
||||
run_flush_cache,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg_);
|
||||
}
|
||||
else
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
HIP_CHECK_ERROR(hipMemsetAsync(arg.p_e_grid,
|
||||
0,
|
||||
arg.M * arg.N * sizeof(EDataType),
|
||||
stream_config.stream_id_));
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
|
||||
}
|
||||
};
|
||||
|
||||
constexpr index_t minimum_occupancy = []() {
|
||||
if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave)
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
}();
|
||||
|
||||
// ThreadwiseTensorSliceTransfer_v7r3 (used in ThreadGroupTensorSliceTransfer_v7r3) is
|
||||
// currently implemented in such a way that all SrcScalarPerVectors must be the same, so
|
||||
// if one of D matrices is column-major, then all SrcScalarPerVectors must be 1. On the
|
||||
// other hand, Split K for 16-bit outputs uses packed atomics so ScalarPerVectors cannot
|
||||
// be odd.
|
||||
constexpr bool AtomicsImplementationExists =
|
||||
!(std::is_same_v<EDataType, ck::half_t> || std::is_same_v<EDataType, ck::bhalf_t> ||
|
||||
std::is_same_v<EDataType, int8_t>) ||
|
||||
(CDEShuffleBlockTransferScalarPerVectors{}[0] % 2 == 0);
|
||||
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
// Tail number always full
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
if constexpr(AtomicsImplementationExists)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Tail number always 1
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
if constexpr(AtomicsImplementationExists)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
using Invoker = typename DeviceGemmCommon::Invoker;
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
return DeviceGemmCommon::IsSupportedArgument(arg);
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,360 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/flush_cache.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename AScaleDataType,
|
||||
typename BDataType,
|
||||
typename BScaleDataType,
|
||||
typename DsDataType,
|
||||
typename CDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t BlockSize,
|
||||
index_t ScaleBlockM, // scale block for M
|
||||
index_t ScaleBlockN, // scale block for N
|
||||
index_t ScaleBlockK, // scale block for K
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t AK1,
|
||||
index_t BK1,
|
||||
index_t MPerWmma,
|
||||
index_t NPerWmma,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
bool BBlockLdsExtraN,
|
||||
index_t CShuffleMRepeatPerShuffle,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename CShuffleBlockTransferScalarPerVectors,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
|
||||
typename ComputeTypeA = CDataType,
|
||||
typename ComputeTypeB = ComputeTypeA,
|
||||
bool PermuteA = false,
|
||||
bool PermuteB = false>
|
||||
struct DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle
|
||||
: public DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
AScaleDataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
DsDataType,
|
||||
CDataType,
|
||||
ScaleBlockM,
|
||||
ScaleBlockN,
|
||||
ScaleBlockK,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
using AScaleLayout = tensor_layout::gemm::ColumnMajor;
|
||||
using BScaleLayout = BLayout;
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_ab_scale<
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
Tuple<ADataType>,
|
||||
AScaleDataType,
|
||||
Tuple<BDataType>,
|
||||
BScaleDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
ScaleBlockM,
|
||||
ScaleBlockN,
|
||||
ScaleBlockK,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerWmma,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
false,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVectors,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
PermuteA,
|
||||
PermuteB,
|
||||
true,
|
||||
AScaleLayout,
|
||||
BScaleLayout>;
|
||||
|
||||
using Argument = typename GridwiseGemm::Argument;
|
||||
int GetPreShuffleParameters() override { return NPerWmma; }
|
||||
|
||||
using DeviceGemmCommon =
|
||||
DeviceGemm_Wmma_CShuffleV3_Common<GridwiseGemm,
|
||||
Tuple<ADataType>,
|
||||
Tuple<BDataType>,
|
||||
DsDataType,
|
||||
CDataType,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
BlockSize,
|
||||
AK1,
|
||||
BK1,
|
||||
GemmSpec,
|
||||
CShuffleBlockTransferScalarPerVectors,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
true>; // IsBPreshuffle
|
||||
|
||||
// Invoker
|
||||
using Invoker = typename DeviceGemmCommon::Invoker;
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
// with splitk the implementation doesn't work
|
||||
// when KRead % ScaleBlockK != 0, independently of K padding
|
||||
if(arg.KBatch > 1 && arg.KRead % ScaleBlockK != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return DeviceGemmCommon::IsSupportedArgument(arg);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_e,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
const std::array<index_t, NumDTensor> StrideDs,
|
||||
index_t StrideC,
|
||||
const void* p_a_scale,
|
||||
const void* p_b_scale,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation cde_element_op,
|
||||
index_t KBatch)
|
||||
{
|
||||
index_t StrideScaleA = ck::is_same_v<AScaleLayout, tensor_layout::gemm::RowMajor>
|
||||
? math::integer_divide_ceil(K, ScaleBlockK)
|
||||
: math::integer_divide_ceil(M, ScaleBlockM);
|
||||
|
||||
index_t StrideScaleB = ck::is_same_v<BScaleLayout, ck::tensor_layout::gemm::ColumnMajor>
|
||||
? math::integer_divide_ceil(K, ScaleBlockK)
|
||||
: math::integer_divide_ceil(N, ScaleBlockN);
|
||||
|
||||
return Argument{std::array<const void*, 1>{p_a},
|
||||
std::array<const void*, 1>{p_b},
|
||||
p_ds,
|
||||
static_cast<CDataType*>(p_e),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
std::array<index_t, 1>{StrideA},
|
||||
std::array<index_t, 1>{StrideB},
|
||||
StrideDs,
|
||||
StrideC,
|
||||
StrideScaleA,
|
||||
StrideScaleB,
|
||||
static_cast<const AScaleDataType*>(p_a_scale),
|
||||
static_cast<const BScaleDataType*>(p_b_scale),
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_e,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
const std::array<ck::index_t, NumDTensor> StrideDs,
|
||||
index_t StrideC,
|
||||
const void* p_a_scale,
|
||||
const void* p_b_scale,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op,
|
||||
index_t KBatch) override
|
||||
{
|
||||
index_t StrideScaleA = ck::is_same_v<AScaleLayout, tensor_layout::gemm::RowMajor>
|
||||
? math::integer_divide_ceil(K, ScaleBlockK)
|
||||
: math::integer_divide_ceil(M, ScaleBlockM);
|
||||
|
||||
index_t StrideScaleB = ck::is_same_v<BScaleLayout, ck::tensor_layout::gemm::ColumnMajor>
|
||||
? math::integer_divide_ceil(K, ScaleBlockK)
|
||||
: math::integer_divide_ceil(N, ScaleBlockN);
|
||||
|
||||
return std::make_unique<Argument>(std::array<const void*, 1>{p_a},
|
||||
std::array<const void*, 1>{p_b},
|
||||
p_ds,
|
||||
static_cast<CDataType*>(p_e),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
std::array<index_t, 1>{StrideA},
|
||||
std::array<index_t, 1>{StrideB},
|
||||
StrideDs,
|
||||
StrideC,
|
||||
StrideScaleA,
|
||||
StrideScaleB,
|
||||
static_cast<const AScaleDataType*>(p_a_scale),
|
||||
static_cast<const BScaleDataType*>(p_b_scale),
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
|
||||
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
|
||||
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
|
||||
|
||||
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
|
||||
{BlockGemmPipelineVersion::v1, "v1"},
|
||||
{BlockGemmPipelineVersion::v2, "v2"},
|
||||
{BlockGemmPipelineVersion::v3, "v3"},
|
||||
{BlockGemmPipelineVersion::v4, "v4"},
|
||||
{BlockGemmPipelineVersion::v5, "v5"}};
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle"
|
||||
<< "<"
|
||||
<< getGemmSpecializationString(GemmSpec) << ", "
|
||||
<< std::string(ALayout::name)[0]
|
||||
<< std::string(BLayout::name)[0]
|
||||
<< std::string(CLayout::name)[0]
|
||||
<< ">"
|
||||
<< " BlkSize: "
|
||||
<< BlockSize << ", "
|
||||
<< "BlkTile: "
|
||||
<< MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
|
||||
<< "WaveTile: "
|
||||
<< MPerWmma<<"x"<<NPerWmma << ", "
|
||||
<< "WaveMap: "
|
||||
<< MRepeat<<"x" << NRepeat<<", "
|
||||
<< "VmemReadVec: "
|
||||
<< ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
|
||||
<< "BlkGemmPipelineScheduler: "
|
||||
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
|
||||
<< "BlkGemmPipelineVersion: "
|
||||
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
|
||||
<< "BlkGemmPipelinePrefetchStages: "
|
||||
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", "
|
||||
<< "KPack: "
|
||||
<< GridwiseGemm::KPack;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -262,6 +262,16 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
@@ -279,22 +289,47 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -315,6 +350,20 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
|
||||
{
|
||||
auto& arg = *dynamic_cast<Argument*>(base_arg);
|
||||
arg.KBatch = KBatch;
|
||||
if(get_warp_size() == 64)
|
||||
{
|
||||
arg.KRead = GridwiseGemm64::CalculateKRead(arg.K, KBatch);
|
||||
arg.KPadded = GridwiseGemm64::CalculateKPadded(arg.K, KBatch);
|
||||
arg.AK0 = GridwiseGemm64::CalculateAK0Padded(arg.K, KBatch);
|
||||
arg.BK0 = GridwiseGemm64::CalculateBK0Padded(arg.K, KBatch);
|
||||
}
|
||||
else
|
||||
{
|
||||
arg.KRead = GridwiseGemm32::CalculateKRead(arg.K, KBatch);
|
||||
arg.KPadded = GridwiseGemm32::CalculateKPadded(arg.K, KBatch);
|
||||
arg.AK0 = GridwiseGemm32::CalculateAK0Padded(arg.K, KBatch);
|
||||
arg.BK0 = GridwiseGemm32::CalculateBK0Padded(arg.K, KBatch);
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
@@ -325,6 +374,13 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
// with splitk the implementation doesn't work
|
||||
// when KRead % ScaleBlockK != 0, independently of K padding
|
||||
if(arg.KBatch > 1 && arg.KRead % ScaleBlockK != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!ck::is_xdl_wmma_supported<ComputeTypeA, ComputeTypeB, MPerXDL, NPerXDL>())
|
||||
{
|
||||
return false;
|
||||
@@ -385,6 +441,14 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
{
|
||||
index_t StrideScaleA = ck::is_same_v<ALayout, tensor_layout::gemm::RowMajor>
|
||||
? math::integer_divide_ceil(K, ScaleBlockK)
|
||||
: math::integer_divide_ceil(M, ScaleBlockM);
|
||||
|
||||
index_t StrideScaleB = ck::is_same_v<BLayout, ck::tensor_layout::gemm::ColumnMajor>
|
||||
? math::integer_divide_ceil(K, ScaleBlockK)
|
||||
: math::integer_divide_ceil(N, ScaleBlockN);
|
||||
|
||||
return Argument{static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
p_ds,
|
||||
@@ -396,6 +460,8 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideC,
|
||||
StrideScaleA,
|
||||
StrideScaleB,
|
||||
static_cast<const AScaleDataType*>(p_a_scale),
|
||||
static_cast<const BScaleDataType*>(p_b_scale),
|
||||
1,
|
||||
@@ -425,6 +491,14 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op) override
|
||||
{
|
||||
index_t StrideScaleA = ck::is_same_v<ALayout, tensor_layout::gemm::RowMajor>
|
||||
? math::integer_divide_ceil(K, ScaleBlockK)
|
||||
: math::integer_divide_ceil(M, ScaleBlockM);
|
||||
|
||||
index_t StrideScaleB = ck::is_same_v<BLayout, ck::tensor_layout::gemm::ColumnMajor>
|
||||
? math::integer_divide_ceil(K, ScaleBlockK)
|
||||
: math::integer_divide_ceil(N, ScaleBlockN);
|
||||
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
p_ds,
|
||||
@@ -436,6 +510,8 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideC,
|
||||
StrideScaleA,
|
||||
StrideScaleB,
|
||||
static_cast<const AScaleDataType*>(p_a_scale),
|
||||
static_cast<const BScaleDataType*>(p_b_scale),
|
||||
1,
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/flush_cache.hpp"
|
||||
@@ -86,12 +86,13 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
|
||||
{
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_b_scale<
|
||||
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_ab_scale<
|
||||
ALayout,
|
||||
BLayout,
|
||||
Tuple<>, // DsLayout
|
||||
CLayout,
|
||||
Tuple<ADataType>,
|
||||
void, // AScaleType
|
||||
Tuple<BDataType>,
|
||||
BScaleDataType,
|
||||
AccDataType,
|
||||
@@ -103,6 +104,7 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
|
||||
CElementwiseOperation,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
0, // ScaleBlockM
|
||||
ScaleBlockN,
|
||||
ScaleBlockK,
|
||||
MPerBlock,
|
||||
@@ -207,7 +209,9 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
|
||||
std::array<index_t, 1>{StrideB},
|
||||
std::array<index_t, 0>{}, // StrideDs_
|
||||
StrideC,
|
||||
0, // StrideScaleA
|
||||
StrideScaleB,
|
||||
nullptr,
|
||||
p_b_scale,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
@@ -245,7 +249,9 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
|
||||
std::array<index_t, 1>{StrideB},
|
||||
std::array<index_t, 0>{}, // StrideDs_
|
||||
StrideC,
|
||||
0, // StrideScaleA
|
||||
StrideScaleB,
|
||||
nullptr, // p_a_scale
|
||||
static_cast<const BScaleDataType*>(p_b_scale),
|
||||
KBatch,
|
||||
a_element_op,
|
||||
|
||||
@@ -38,7 +38,8 @@ template <typename GridwiseGemm,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
typename ComputeTypeA,
|
||||
typename ComputeTypeB>
|
||||
typename ComputeTypeB,
|
||||
bool IsBPreShuffled = false>
|
||||
struct DeviceGemm_Wmma_CShuffleV3_Common
|
||||
{
|
||||
|
||||
@@ -189,61 +190,174 @@ struct DeviceGemm_Wmma_CShuffleV3_Common
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
// Tail number always full
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
if constexpr(IsBPreShuffled)
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if constexpr(AtomicsImplementationExists)
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
if constexpr(AtomicsImplementationExists)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// TODO: Implement
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Tail number always 1
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
if constexpr(AtomicsImplementationExists)
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
if constexpr(AtomicsImplementationExists)
|
||||
{
|
||||
const auto kernel = kernel_gemm_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
else
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(IsBPreShuffled)
|
||||
{
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
if constexpr(AtomicsImplementationExists)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
if constexpr(AtomicsImplementationExists)
|
||||
{
|
||||
const auto kernel = kernel_gemm_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -299,6 +413,14 @@ struct DeviceGemm_Wmma_CShuffleV3_Common
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(IsBPreShuffled)
|
||||
{
|
||||
if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -388,11 +388,11 @@ struct ABTransferThreadTiles
|
||||
// 1 - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma - K1
|
||||
return transform_tensor_descriptor(
|
||||
BlockDesc{},
|
||||
make_tuple(
|
||||
make_unmerge_transform(make_tuple(Number<ABK0 / KRow>{}, KRow, Number<1>{})),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<MNRepeat>{}, Number<MNWaves>{}, Number<MNPerWmma>{})),
|
||||
make_pass_through_transform(Number<ABK1>{})),
|
||||
make_tuple(make_unmerge_transform(make_tuple(
|
||||
Number<ABK0 / KRow>{}, KRow, Number<KPack / KRow / ABK1>{})),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<MNRepeat>{}, Number<MNWaves>{}, Number<MNPerWmma>{})),
|
||||
make_pass_through_transform(Number<ABK1>{})),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<2, 4, 0>{}, Sequence<1, 3, 5>{}, Sequence<6>{}));
|
||||
}
|
||||
|
||||
@@ -895,8 +895,9 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
|
||||
c_thread_buf.Clear();
|
||||
|
||||
// Empty BScale struct for the blockwise pipeline.
|
||||
using BScale = typename BlockwiseGemmPipe::Empty;
|
||||
auto b_scale_struct = BScale{};
|
||||
using ABScale = typename BlockwiseGemmPipe::Empty;
|
||||
auto a_scale_struct = ABScale{};
|
||||
auto b_scale_struct = ABScale{};
|
||||
|
||||
/*******************************************************************************/
|
||||
//
|
||||
@@ -919,6 +920,7 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
|
||||
b0_block_buf,
|
||||
b0_block_slice_copy_step,
|
||||
acc0_thread_buf,
|
||||
a_scale_struct,
|
||||
b_scale_struct,
|
||||
KBlockMainLoop,
|
||||
1); // num_k_block_per_scale
|
||||
|
||||
@@ -618,8 +618,9 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[Number<BlockMapNBlockIndex>{}]);
|
||||
|
||||
// BScale struct (Empty)
|
||||
using BScale = typename BlockwiseGemmPipe::Empty;
|
||||
auto b_scale_struct = BScale{};
|
||||
using Scale = typename BlockwiseGemmPipe::Empty;
|
||||
auto a_scale_struct = Scale{};
|
||||
auto b_scale_struct = Scale{};
|
||||
|
||||
const index_t num_k_block_per_scale = GetKBlockPerScale();
|
||||
|
||||
@@ -627,6 +628,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
decltype(bs_grid_desc_bk0_n_bk1),
|
||||
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(a_scale_struct),
|
||||
decltype(b_scale_struct),
|
||||
decltype(epilogue_args),
|
||||
HasMainKBlockLoop,
|
||||
@@ -646,6 +648,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
block_m_id,
|
||||
block_n_id,
|
||||
num_k_block_per_scale,
|
||||
a_scale_struct,
|
||||
b_scale_struct,
|
||||
epilogue_args,
|
||||
k_id);
|
||||
|
||||
@@ -23,6 +23,7 @@ template <typename ALayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename AsDataType,
|
||||
typename AScaleType,
|
||||
typename BsDataType,
|
||||
typename BScaleType,
|
||||
typename AccDataType,
|
||||
@@ -34,6 +35,7 @@ template <typename ALayout,
|
||||
typename CDEElementwiseOperation,
|
||||
tensor_operation::device::GemmSpecialization GemmSpec,
|
||||
index_t BlockSize,
|
||||
index_t ScaleBlockM,
|
||||
index_t ScaleBlockN, // scale N
|
||||
index_t ScaleBlockK, // scale K
|
||||
index_t MPerBlock,
|
||||
@@ -65,13 +67,16 @@ template <typename ALayout,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v4,
|
||||
typename ComputeTypeA = EDataType,
|
||||
typename ComputeTypeB = ComputeTypeA,
|
||||
bool PermuteA = false,
|
||||
bool PermuteB = false>
|
||||
struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
typename ComputeTypeA,
|
||||
typename ComputeTypeB,
|
||||
bool PermuteA,
|
||||
bool PermuteB,
|
||||
bool IsBPreShuffled = false,
|
||||
typename AScaleLayout = ALayout,
|
||||
typename BScaleLayout = BLayout>
|
||||
struct GridwiseGemm_wmma_cshuffle_v3_ab_scale
|
||||
: GridwiseGemm_wmma_cshuffle_v3_base<
|
||||
ALayout,
|
||||
BLayout,
|
||||
@@ -123,7 +128,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
ComputeTypeB,
|
||||
PermuteA,
|
||||
PermuteB,
|
||||
false,
|
||||
IsBPreShuffled,
|
||||
true>
|
||||
{
|
||||
using Base = GridwiseGemm_wmma_cshuffle_v3_base<
|
||||
@@ -177,7 +182,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
ComputeTypeB,
|
||||
PermuteA,
|
||||
PermuteB,
|
||||
false,
|
||||
IsBPreShuffled,
|
||||
true>;
|
||||
|
||||
using Base::I0;
|
||||
@@ -233,6 +238,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
std::array<index_t, NumBTensor> StrideBs_,
|
||||
std::array<index_t, NumDTensor> StrideDs_,
|
||||
index_t StrideE_,
|
||||
index_t StrideScaleA_,
|
||||
index_t StrideScaleB_,
|
||||
index_t KBatch_)
|
||||
: M{M_},
|
||||
@@ -242,6 +248,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
StrideBs{StrideBs_},
|
||||
StrideDs{StrideDs_},
|
||||
StrideE{StrideE_},
|
||||
StrideScaleA{StrideScaleA_},
|
||||
StrideScaleB{StrideScaleB_},
|
||||
KBatch{KBatch_},
|
||||
MPadded{CalculateMPadded(M_)},
|
||||
@@ -251,7 +258,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
AK0{CalculateAK0Padded(K_, KBatch_)},
|
||||
BK0{CalculateBK0Padded(K_, KBatch_)},
|
||||
MBlock{CalculateMBlock(M_)},
|
||||
NBlock{CalculateNBlock(N_)}
|
||||
NBlock{CalculateNBlock(N_)},
|
||||
Kt{K_}
|
||||
{
|
||||
}
|
||||
|
||||
@@ -275,11 +283,11 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
});
|
||||
std::cout << " }, ";
|
||||
}
|
||||
std::cout << "SE:" << StrideE << ", " << "SScaleB:" << StrideScaleB << ", "
|
||||
<< "MP:" << MPadded << ", " << "NP:" << NPadded << ", " << "KRead:" << KRead
|
||||
<< ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0
|
||||
<< ", " << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}"
|
||||
<< std::endl;
|
||||
std::cout << "SE:" << StrideE << ", " << "SScaleA:" << StrideScaleA << ", "
|
||||
<< "SScaleB:" << StrideScaleB << ", " << "MP:" << MPadded << ", "
|
||||
<< "NP:" << NPadded << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded
|
||||
<< ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", "
|
||||
<< "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl;
|
||||
}
|
||||
|
||||
index_t M;
|
||||
@@ -289,6 +297,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
std::array<index_t, NumBTensor> StrideBs;
|
||||
std::array<index_t, NumDTensor> StrideDs;
|
||||
index_t StrideE;
|
||||
index_t StrideScaleA;
|
||||
index_t StrideScaleB;
|
||||
index_t KBatch;
|
||||
index_t MPadded;
|
||||
@@ -299,6 +308,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
index_t BK0;
|
||||
index_t MBlock;
|
||||
index_t NBlock;
|
||||
index_t Kt;
|
||||
};
|
||||
|
||||
// Argument
|
||||
@@ -315,7 +325,9 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
std::array<index_t, NumBTensor> StrideBs_,
|
||||
std::array<index_t, NumDTensor> StrideDs_,
|
||||
index_t StrideE_,
|
||||
index_t StrideScaleA_,
|
||||
index_t StrideScaleB_,
|
||||
const AScaleType* p_a_scale_grid_,
|
||||
const BScaleType* p_b_scale_grid_,
|
||||
index_t k_batch_,
|
||||
AElementwiseOperation a_element_op_,
|
||||
@@ -329,12 +341,14 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
StrideBs_,
|
||||
StrideDs_,
|
||||
StrideE_,
|
||||
StrideScaleA_,
|
||||
StrideScaleB_,
|
||||
k_batch_},
|
||||
p_as_grid{},
|
||||
p_bs_grid{},
|
||||
p_ds_grid{},
|
||||
p_e_grid{p_e_grid_},
|
||||
p_a_scale_grid{p_a_scale_grid_},
|
||||
p_b_scale_grid{p_b_scale_grid_},
|
||||
a_element_op{a_element_op_},
|
||||
b_element_op{b_element_op_},
|
||||
@@ -379,6 +393,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
DsGridPointer p_ds_grid;
|
||||
EDataType* p_e_grid;
|
||||
|
||||
const AScaleType* p_a_scale_grid;
|
||||
const BScaleType* p_b_scale_grid;
|
||||
const AElementwiseOperation a_element_op;
|
||||
const BElementwiseOperation b_element_op;
|
||||
@@ -407,34 +422,52 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
[&](auto i) { a_k_split_offset[i] = k_id * karg.KRead * karg.StrideAs[i]; });
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
if constexpr(IsBPreShuffled)
|
||||
{
|
||||
static_for<0, NumBTensor, 1>{}(
|
||||
[&](auto i) { b_k_split_offset[i] = k_id * karg.KRead * karg.StrideBs[i]; });
|
||||
static_for<0, NumBTensor, 1>{}([&](auto i) { b_k_split_offset[i] = 0; });
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
else
|
||||
{
|
||||
if constexpr(!PermuteB)
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
{
|
||||
static_for<0, NumBTensor, 1>{}(
|
||||
[&](auto i) { b_k_split_offset[i] = k_id * karg.KRead / BPackedSize; });
|
||||
static_for<0, NumBTensor, 1>{}([&](auto i) {
|
||||
b_k_split_offset[i] = k_id * karg.KRead * karg.StrideBs[i];
|
||||
});
|
||||
}
|
||||
else
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
{
|
||||
const int k0_offset = karg.KRead * karg.N;
|
||||
static_for<0, NumBTensor, 1>{}(
|
||||
[&](auto i) { b_k_split_offset[i] = k_id * k0_offset / BPackedSize; });
|
||||
if constexpr(!PermuteB)
|
||||
{
|
||||
static_for<0, NumBTensor, 1>{}(
|
||||
[&](auto i) { b_k_split_offset[i] = k_id * karg.KRead / BPackedSize; });
|
||||
}
|
||||
else
|
||||
{
|
||||
const int k0_offset = karg.KRead * karg.N;
|
||||
static_for<0, NumBTensor, 1>{}(
|
||||
[&](auto i) { b_k_split_offset[i] = k_id * k0_offset / BPackedSize; });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate B scale offset
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
// Calculate A scale offset
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, AScaleLayout>)
|
||||
{
|
||||
scale_k_split_offset = k_id * (karg.KRead / ScaleBlockK) * karg.StrideB;
|
||||
scale_a_k_split_offset = k_id * (karg.KRead / ScaleBlockK);
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, AScaleLayout>)
|
||||
{
|
||||
scale_k_split_offset = k_id * (karg.KRead / ScaleBlockK);
|
||||
scale_a_k_split_offset = k_id * (karg.KRead / ScaleBlockK) * karg.StrideScaleA;
|
||||
}
|
||||
|
||||
// Calculate B scale offset
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BScaleLayout>)
|
||||
{
|
||||
scale_b_k_split_offset = k_id * (karg.KRead / ScaleBlockK) * karg.StrideScaleB;
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BScaleLayout>)
|
||||
{
|
||||
scale_b_k_split_offset = k_id * (karg.KRead / ScaleBlockK);
|
||||
}
|
||||
|
||||
if(k_id < karg.KBatch - 1)
|
||||
@@ -458,77 +491,225 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
|
||||
std::array<index_t, NumATensor> a_k_split_offset;
|
||||
std::array<index_t, NumBTensor> b_k_split_offset;
|
||||
index_t scale_k_split_offset; // New member for scale matrix offset
|
||||
index_t scale_a_k_split_offset; // A scale matrix offset
|
||||
index_t scale_b_k_split_offset; // B scale matrix offset
|
||||
index_t c_reduce_offset;
|
||||
};
|
||||
|
||||
using BlockwiseGemmPipe = typename Base::BlockwiseGemmPipe;
|
||||
|
||||
// return block_id to C matrix tile idx (m0, n0) mapping
|
||||
// if arch = gfx942
|
||||
using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
|
||||
// using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
|
||||
|
||||
template <index_t NumberOfBuffers, typename BScaleGridDesc_BN_AK>
|
||||
__device__ static auto MakeBScale(const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak,
|
||||
const BScaleType* p_b_scale_grid,
|
||||
index_t block_n_id)
|
||||
__device__ static constexpr auto
|
||||
MakeAScaleGridDesciptor_M_K(index_t M, index_t K, index_t StrideScaleA)
|
||||
{
|
||||
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());
|
||||
const auto BM = math::integer_divide_ceil(M, ScaleBlockM);
|
||||
const auto BK = math::integer_divide_ceil(K, ScaleBlockK);
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, AScaleLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(StrideScaleA, I1));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, AScaleLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(I1, StrideScaleA));
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr auto wmma =
|
||||
WmmaSelector<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma>{};
|
||||
static constexpr auto KPerThread = wmma.selected_wmma.k_per_wmma;
|
||||
template <index_t NumberOfBuffers>
|
||||
__device__ static auto
|
||||
MakeAScale(const Problem& problem, const AScaleType* p_a_scale_grid, index_t block_m_id)
|
||||
{
|
||||
if constexpr(ck::is_same_v<AScaleType, void>)
|
||||
{
|
||||
using AScale = typename BlockwiseGemmPipe::Empty;
|
||||
return AScale{};
|
||||
}
|
||||
else
|
||||
{
|
||||
#if defined(__gfx11__)
|
||||
// TODO: remove this restriction
|
||||
static_assert(ScaleBlockM >= MPerWmma,
|
||||
"ScaleBlockM must be greater equal than MPerWmma");
|
||||
#endif
|
||||
static_assert(
|
||||
ScaleBlockK >=
|
||||
WmmaSelector<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma>::
|
||||
selected_wmma.k_per_wmma,
|
||||
"ScaleBlockK must be greater equal than KPerWmma");
|
||||
|
||||
static constexpr auto ScaleSliceSizeN = NRepeat;
|
||||
static constexpr auto ScaleSliceSizeK = (KPerThread + ScaleBlockK - 1) / ScaleBlockK;
|
||||
const auto a_scale_grid_desc_am_ak =
|
||||
MakeAScaleGridDesciptor_M_K(problem.M, problem.K, problem.StrideScaleA);
|
||||
|
||||
constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<ScaleSliceSizeN>{}, Number<ScaleSliceSizeK>{}));
|
||||
const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
|
||||
|
||||
constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);
|
||||
constexpr auto wmma =
|
||||
WmmaSelector<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma>{};
|
||||
constexpr auto RegSizePerWmmaFull =
|
||||
wmma.selected_wmma.num_acc_vgprs_per_wave * wmma.selected_wmma.acc_pack_number;
|
||||
constexpr auto RegSizePerWmma =
|
||||
math::integer_divide_ceil(RegSizePerWmmaFull, ScaleBlockM);
|
||||
|
||||
auto b_thread_offset_n = get_thread_local_1d_id() % NPerWmma +
|
||||
(get_thread_local_1d_id() / 32) % NWaves * NPerWmma;
|
||||
auto b_thread_offset_k = (get_thread_local_1d_id() % 32) / NPerWmma * KPerThread;
|
||||
constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma);
|
||||
constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);
|
||||
|
||||
auto b_scale_thread_copy =
|
||||
ThreadwiseTensorSliceTransfer_v2<BScaleType,
|
||||
BScaleType,
|
||||
decltype(b_scale_grid_desc_bn_ak),
|
||||
decltype(b_scale_thread_desc),
|
||||
Sequence<1, ScaleSliceSizeK>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
ScaleSliceSizeK,
|
||||
1,
|
||||
false>(
|
||||
b_scale_grid_desc_bn_ak,
|
||||
make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset_n,
|
||||
b_thread_offset_k / ScaleBlockK));
|
||||
constexpr auto ScaleSliceSizeM =
|
||||
ScaleBlockM < MPerWmma ? MRepeat * RegSizePerWmma
|
||||
: math::integer_divide_ceil(MPerBlock, ScaleBlockM);
|
||||
constexpr auto ScaleSliceStrideM =
|
||||
math::integer_divide_ceil(MWaves * MPerWmma, ScaleBlockM);
|
||||
constexpr auto ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK);
|
||||
|
||||
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
|
||||
b_scale_thread_desc.GetElementSpaceSize());
|
||||
constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<ScaleSliceSizeM>{}, Number<ScaleSliceSizeK>{}));
|
||||
|
||||
using BScale =
|
||||
typename BlockwiseGemmPipe::template BScale<ScaleSliceSizeN,
|
||||
ScaleSliceSizeK,
|
||||
NWaves,
|
||||
ScaleBlockK,
|
||||
NumberOfBuffers,
|
||||
decltype(b_scale_grid_desc_bn_ak),
|
||||
decltype(b_scale_thread_copy),
|
||||
decltype(b_scale_grid_buf),
|
||||
decltype(b_scale_thread_buf),
|
||||
decltype(b_scale_thread_desc)>;
|
||||
auto a_thread_offset_m =
|
||||
((get_thread_local_1d_id() % 32) / MPerWmma * RegSizePerWmma) /
|
||||
math::integer_divide_ceil(ScaleBlockM, RegSizePerWmmaFull) +
|
||||
(get_thread_local_1d_id() / 32) / NWaves * MPerWmma / ScaleBlockM;
|
||||
|
||||
return BScale{b_scale_grid_desc_bn_ak, b_scale_thread_copy, b_scale_grid_buf};
|
||||
constexpr index_t VectorDim =
|
||||
is_same<tensor_layout::gemm::ColumnMajor, AScaleLayout>::value ? 0 : 1;
|
||||
constexpr index_t VectorSize =
|
||||
is_same<tensor_layout::gemm::ColumnMajor, AScaleLayout>::value ? RegSizePerWmma
|
||||
: ScaleSliceSizeK;
|
||||
|
||||
auto a_scale_thread_copy =
|
||||
ThreadwiseTensorSliceTransfer_v2<AScaleType,
|
||||
AScaleType,
|
||||
decltype(a_scale_grid_desc_am_ak),
|
||||
decltype(a_scale_thread_desc),
|
||||
Sequence<RegSizePerWmma, ScaleSliceSizeK>,
|
||||
Sequence<0, 1>,
|
||||
VectorDim,
|
||||
VectorSize,
|
||||
1,
|
||||
true>(
|
||||
a_scale_grid_desc_am_ak,
|
||||
make_multi_index(block_m_id * MPerBlock / ScaleBlockM + a_thread_offset_m, 0));
|
||||
|
||||
auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AScaleType>(
|
||||
a_scale_thread_desc.GetElementSpaceSize());
|
||||
|
||||
using AScale =
|
||||
typename BlockwiseGemmPipe::template ABScale<ScaleSliceSizeM,
|
||||
ScaleSliceStrideM,
|
||||
ScaleSliceSizeK,
|
||||
NumberOfBuffers,
|
||||
RegSizePerWmma,
|
||||
decltype(a_scale_grid_desc_am_ak),
|
||||
decltype(a_scale_thread_copy),
|
||||
decltype(a_scale_grid_buf),
|
||||
decltype(a_scale_thread_buf),
|
||||
decltype(a_scale_thread_desc)>;
|
||||
|
||||
return AScale{a_scale_grid_desc_am_ak, a_scale_thread_copy, a_scale_grid_buf};
|
||||
}
|
||||
}
|
||||
|
||||
__device__ static constexpr auto
|
||||
MakeBScaleGridDesciptor_N_K(index_t N, index_t K, index_t StrideScaleB)
|
||||
{
|
||||
const auto BN = math::integer_divide_ceil(N, ScaleBlockN);
|
||||
const auto BK = math::integer_divide_ceil(K, ScaleBlockK);
|
||||
if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BScaleLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(StrideScaleB, I1));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::RowMajor, BScaleLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(I1, StrideScaleB));
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t NumberOfBuffers>
|
||||
__device__ static auto
|
||||
MakeBScale(const Problem& problem, const BScaleType* p_b_scale_grid, index_t block_n_id)
|
||||
{
|
||||
if constexpr(ck::is_same_v<BScaleType, void>)
|
||||
{
|
||||
using BScale = typename BlockwiseGemmPipe::Empty;
|
||||
return BScale{};
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(
|
||||
ScaleBlockK >=
|
||||
WmmaSelector<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma>::
|
||||
selected_wmma.k_per_wmma,
|
||||
"ScaleBlockK must be greater equal than KPerWmma");
|
||||
|
||||
const auto b_scale_grid_desc_bn_ak =
|
||||
MakeBScaleGridDesciptor_N_K(problem.N, problem.K, problem.StrideScaleB);
|
||||
|
||||
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());
|
||||
|
||||
constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);
|
||||
|
||||
constexpr auto ScaleSliceSizeN =
|
||||
ScaleBlockN < NPerWmma ? NRepeat
|
||||
: math::integer_divide_ceil(NPerBlock, ScaleBlockN);
|
||||
constexpr auto ScaleSliceStrideN =
|
||||
math::integer_divide_ceil(NWaves * NPerWmma, ScaleBlockN);
|
||||
constexpr auto ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK);
|
||||
|
||||
constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<ScaleSliceSizeN>{}, Number<ScaleSliceSizeK>{}));
|
||||
|
||||
auto b_thread_offset_n = (get_thread_local_1d_id() % NPerWmma +
|
||||
(get_thread_local_1d_id() / 32) % NWaves * NPerWmma) /
|
||||
ScaleBlockN;
|
||||
|
||||
constexpr index_t VectorDim =
|
||||
is_same<tensor_layout::gemm::RowMajor, BScaleLayout>::value ? 0 : 1;
|
||||
constexpr index_t VectorSize =
|
||||
is_same<tensor_layout::gemm::RowMajor, BScaleLayout>::value ? 1 : ScaleSliceSizeK;
|
||||
|
||||
auto b_scale_thread_copy =
|
||||
ThreadwiseTensorSliceTransfer_v2<BScaleType,
|
||||
BScaleType,
|
||||
decltype(b_scale_grid_desc_bn_ak),
|
||||
decltype(b_scale_thread_desc),
|
||||
Sequence<1, ScaleSliceSizeK>,
|
||||
Sequence<0, 1>,
|
||||
VectorDim,
|
||||
VectorSize,
|
||||
1,
|
||||
true>(
|
||||
b_scale_grid_desc_bn_ak,
|
||||
make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset_n, 0));
|
||||
|
||||
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BScaleType>(
|
||||
b_scale_thread_desc.GetElementSpaceSize());
|
||||
|
||||
using BScale =
|
||||
typename BlockwiseGemmPipe::template ABScale<ScaleSliceSizeN,
|
||||
ScaleSliceStrideN,
|
||||
ScaleSliceSizeK,
|
||||
NumberOfBuffers,
|
||||
1,
|
||||
decltype(b_scale_grid_desc_bn_ak),
|
||||
decltype(b_scale_thread_copy),
|
||||
decltype(b_scale_grid_buf),
|
||||
decltype(b_scale_thread_buf),
|
||||
decltype(b_scale_thread_desc)>;
|
||||
|
||||
return BScale{b_scale_grid_desc_bn_ak, b_scale_thread_copy, b_scale_grid_buf};
|
||||
}
|
||||
}
|
||||
|
||||
__device__ static index_t GetKBlockPerScale()
|
||||
{
|
||||
return (ScaleBlockK + KPerBlock - 1) / KPerBlock;
|
||||
if constexpr(ck::is_same_v<AScaleType, void> && ck::is_same_v<BScaleType, void>)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
return (ScaleBlockK + KPerBlock - 1) / KPerBlock;
|
||||
}
|
||||
}
|
||||
|
||||
template <bool HasMainKBlockLoop,
|
||||
@@ -539,18 +720,21 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
BsGridPointer& p_bs_grid,
|
||||
DsGridPointer& p_ds_grid,
|
||||
EDataType* p_e_grid,
|
||||
const AScaleType* p_a_scale_grid,
|
||||
const BScaleType* p_b_scale_grid,
|
||||
void* p_shared,
|
||||
const Problem& problem,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op,
|
||||
EpilogueArgument& epilogue_args)
|
||||
EpilogueArgument& epilogue_args,
|
||||
const index_t k_id = 0)
|
||||
{
|
||||
const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
|
||||
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
|
||||
const index_t K_b = IsBPreShuffled ? problem.Kt : problem.K;
|
||||
const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1(
|
||||
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideBs, problem.BK0);
|
||||
K_b, problem.KPadded, problem.N, problem.NPadded, problem.StrideBs, problem.BK0);
|
||||
const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
|
||||
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
|
||||
const auto e_grid_desc_m_n = Base::template MakeDEGridDescriptor_M_N<ELayout>(
|
||||
@@ -562,12 +746,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
e_grid_desc_m_n, problem.MBlock, problem.NBlock);
|
||||
|
||||
// B Scale grid
|
||||
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
|
||||
make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
|
||||
math::integer_divide_ceil(problem.K, ScaleBlockK)),
|
||||
make_tuple(problem.StrideScaleB, 1));
|
||||
|
||||
// divide block work by [M, N]
|
||||
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
|
||||
|
||||
@@ -585,8 +763,11 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
|
||||
const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
|
||||
|
||||
// AScale struct
|
||||
auto a_scale_struct = MakeAScale<1>(problem, p_a_scale_grid, block_m_id);
|
||||
|
||||
// BScale struct
|
||||
auto b_scale_struct = MakeBScale<1>(b_scale_grid_desc_bn_ak, p_b_scale_grid, block_n_id);
|
||||
auto b_scale_struct = MakeBScale<1>(problem, p_b_scale_grid, block_n_id);
|
||||
|
||||
const index_t num_k_block_per_scale = GetKBlockPerScale();
|
||||
|
||||
@@ -594,6 +775,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
decltype(bs_grid_desc_bk0_n_bk1),
|
||||
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(a_scale_struct),
|
||||
decltype(b_scale_struct),
|
||||
decltype(epilogue_args),
|
||||
HasMainKBlockLoop,
|
||||
@@ -613,8 +795,10 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
block_m_id,
|
||||
block_n_id,
|
||||
num_k_block_per_scale,
|
||||
a_scale_struct,
|
||||
b_scale_struct,
|
||||
epilogue_args);
|
||||
epilogue_args,
|
||||
k_id);
|
||||
}
|
||||
|
||||
// NOTE: Wrapper function to have __global__ function in common
|
||||
@@ -626,7 +810,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
__device__ static void Run(void* p_shared,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
Argument& karg,
|
||||
EpilogueArgument& epilogue_args)
|
||||
EpilogueArgument& epilogue_args,
|
||||
const index_t k_id = 0)
|
||||
{
|
||||
// shift A matrices pointer for splitk
|
||||
AsGridPointer p_as_grid_splitk;
|
||||
@@ -644,18 +829,40 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
|
||||
splitk_batch_offset.b_k_split_offset[i];
|
||||
});
|
||||
|
||||
const AScaleType* p_a_scale_grid_ptr;
|
||||
if constexpr(ck::is_same_v<AScaleType, void>)
|
||||
{
|
||||
p_a_scale_grid_ptr = karg.p_a_scale_grid;
|
||||
}
|
||||
else
|
||||
{
|
||||
p_a_scale_grid_ptr = karg.p_a_scale_grid + splitk_batch_offset.scale_a_k_split_offset;
|
||||
}
|
||||
|
||||
const BScaleType* p_b_scale_grid_ptr;
|
||||
if constexpr(ck::is_same_v<BScaleType, void>)
|
||||
{
|
||||
p_b_scale_grid_ptr = karg.p_b_scale_grid;
|
||||
}
|
||||
else
|
||||
{
|
||||
p_b_scale_grid_ptr = karg.p_b_scale_grid + splitk_batch_offset.scale_b_k_split_offset;
|
||||
}
|
||||
|
||||
Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
|
||||
p_as_grid_splitk,
|
||||
p_bs_grid_splitk,
|
||||
karg.p_ds_grid,
|
||||
karg.p_e_grid + splitk_batch_offset.c_reduce_offset,
|
||||
karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset,
|
||||
p_a_scale_grid_ptr,
|
||||
p_b_scale_grid_ptr,
|
||||
p_shared,
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.cde_element_op,
|
||||
epilogue_args);
|
||||
epilogue_args,
|
||||
k_id);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -69,6 +69,48 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
index_t MinimumOccupancy = 1,
|
||||
TailNumber TailNum = TailNumber::Full>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
#endif
|
||||
kernel_gemm_b_preshuffle_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg)
|
||||
{
|
||||
#if(defined(__gfx11__) || defined(__gfx12__))
|
||||
#if defined(__gfx11__)
|
||||
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
|
||||
using e_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
|
||||
if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
|
||||
(std::is_same_v<e_data_type, ck::half_t> ||
|
||||
std::is_same_v<e_data_type, ck::bhalf_t>)))
|
||||
{
|
||||
#endif
|
||||
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
|
||||
typename GridwiseGemm::EpilogueCShuffle>();
|
||||
__shared__ char p_shared[LDS_size];
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
|
||||
const index_t num_k_per_block = math::integer_divide_ceil(karg.K, GridwiseGemm::KPack);
|
||||
const index_t k_id = blockIdx.z * num_k_per_block;
|
||||
|
||||
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
|
||||
p_shared, splitk_batch_offset, karg, epilogue_args, k_id);
|
||||
|
||||
#if defined(__gfx11__)
|
||||
}
|
||||
#endif
|
||||
#else
|
||||
ignore = karg;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
@@ -162,7 +204,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
|
||||
static constexpr index_t KInnerB = ck::math::integer_divide_ceil(BK1Value, KPerWmmaBlk);
|
||||
|
||||
static constexpr index_t KInner = ck::math::min(KInnerA, KInnerB);
|
||||
static constexpr index_t KInner = IsBPreShuffled ? KInnerB : ck::math::min(KInnerA, KInnerB);
|
||||
|
||||
static constexpr index_t KPack =
|
||||
KInner *
|
||||
@@ -966,6 +1008,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
typename BGridDesc_BK0_N_K1,
|
||||
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename AScaleStruct,
|
||||
typename BScaleStruct,
|
||||
typename EpilogueArgument,
|
||||
bool HasMainKBlockLoop,
|
||||
@@ -988,6 +1031,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
const index_t& block_m_id,
|
||||
const index_t& block_n_id,
|
||||
const index_t& num_k_block_per_scale,
|
||||
AScaleStruct& a_scale_struct,
|
||||
BScaleStruct& b_scale_struct,
|
||||
EpilogueArgument& epilogue_args,
|
||||
const index_t k_id = 0)
|
||||
@@ -1072,6 +1116,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
b_block_buf,
|
||||
b_block_slice_copy_step,
|
||||
c_thread_buf,
|
||||
a_scale_struct,
|
||||
b_scale_struct,
|
||||
num_k_block_main_loop,
|
||||
num_k_block_per_scale);
|
||||
|
||||
@@ -43,13 +43,15 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
{
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_a_grid,
|
||||
karg.p_b_grid,
|
||||
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
|
||||
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
|
||||
karg.p_ds_grid,
|
||||
karg.p_c_grid,
|
||||
karg.p_a_scale_grid,
|
||||
karg.p_b_scale_grid,
|
||||
karg.p_a_scale_grid + splitk_batch_offset.scale_a_k_split_offset,
|
||||
karg.p_b_scale_grid + splitk_batch_offset.scale_b_k_split_offset,
|
||||
p_shared,
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
@@ -405,31 +407,33 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
}
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto MakeAScaleGridDesciptor_M_K(index_t M, index_t K)
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeAScaleGridDesciptor_M_K(index_t M, index_t K, index_t StrideScaleA)
|
||||
{
|
||||
const auto BM = math::integer_divide_ceil(M, ScaleBlockM);
|
||||
const auto BK = math::integer_divide_ceil(K, ScaleBlockK);
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(BK, I1));
|
||||
return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(StrideScaleA, I1));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(I1, BM));
|
||||
return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(I1, StrideScaleA));
|
||||
}
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto MakeBScaleGridDesciptor_N_K(index_t N, index_t K)
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeBScaleGridDesciptor_N_K(index_t N, index_t K, index_t StrideScaleB)
|
||||
{
|
||||
const auto BN = math::integer_divide_ceil(N, ScaleBlockN);
|
||||
const auto BK = math::integer_divide_ceil(K, ScaleBlockK);
|
||||
if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(BK, I1));
|
||||
return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(StrideScaleB, I1));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(I1, BN));
|
||||
return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(I1, StrideScaleB));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -548,6 +552,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
index_t StrideB_,
|
||||
std::array<index_t, NumDTensor> StrideDs_,
|
||||
index_t StrideC_,
|
||||
index_t StrideScaleA_,
|
||||
index_t StrideScaleB_,
|
||||
index_t KBatch_)
|
||||
: M{M_},
|
||||
N{N_},
|
||||
@@ -556,6 +562,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
StrideB{StrideB_},
|
||||
StrideDs{StrideDs_},
|
||||
StrideC{StrideC_},
|
||||
StrideScaleA{StrideScaleA_},
|
||||
StrideScaleB{StrideScaleB_},
|
||||
KBatch{KBatch_},
|
||||
MPadded{CalculateMPadded(M_)},
|
||||
NPadded{CalculateNPadded(N_)},
|
||||
@@ -585,7 +593,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
index_t StrideB;
|
||||
std::array<index_t, NumDTensor> StrideDs;
|
||||
index_t StrideC;
|
||||
|
||||
index_t StrideScaleA;
|
||||
index_t StrideScaleB;
|
||||
index_t KBatch;
|
||||
index_t MPadded;
|
||||
index_t NPadded;
|
||||
@@ -611,13 +620,24 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
index_t StrideB_,
|
||||
std::array<index_t, NumDTensor> StrideDs_,
|
||||
index_t StrideC_,
|
||||
index_t StrideScaleA_,
|
||||
index_t StrideScaleB_,
|
||||
const AScaleType* p_a_scale_grid_,
|
||||
const BScaleType* p_b_scale_grid_,
|
||||
index_t k_batch_,
|
||||
AElementwiseOperation a_element_op_,
|
||||
BElementwiseOperation b_element_op_,
|
||||
CElementwiseOperation c_element_op_)
|
||||
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideC_, k_batch_},
|
||||
: Problem{M_,
|
||||
N_,
|
||||
K_,
|
||||
StrideA_,
|
||||
StrideB_,
|
||||
StrideDs_,
|
||||
StrideC_,
|
||||
StrideScaleA_,
|
||||
StrideScaleB_,
|
||||
k_batch_},
|
||||
p_a_grid{p_a_grid_},
|
||||
p_b_grid{p_b_grid_},
|
||||
p_ds_grid{},
|
||||
@@ -673,6 +693,28 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
b_k_split_offset = blockIdx.z * karg.KRead;
|
||||
}
|
||||
|
||||
// Calculate A scale offset
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
|
||||
{
|
||||
scale_a_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK);
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
|
||||
{
|
||||
scale_a_k_split_offset =
|
||||
blockIdx.z * (karg.KRead / ScaleBlockK) * karg.StrideScaleA;
|
||||
}
|
||||
|
||||
// Calculate B scale offset
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
{
|
||||
scale_b_k_split_offset =
|
||||
blockIdx.z * (karg.KRead / ScaleBlockK) * karg.StrideScaleB;
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
{
|
||||
scale_b_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK);
|
||||
}
|
||||
|
||||
if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
|
||||
{
|
||||
karg.K = karg.KRead;
|
||||
@@ -685,6 +727,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
|
||||
index_t a_k_split_offset;
|
||||
index_t b_k_split_offset;
|
||||
index_t scale_a_k_split_offset; // A scale matrix offset
|
||||
index_t scale_b_k_split_offset; // B scale matrix offset
|
||||
};
|
||||
|
||||
__device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
|
||||
@@ -1221,8 +1265,10 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
|
||||
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
|
||||
|
||||
const auto a_scale_grid_desc_am_ak = MakeAScaleGridDesciptor_M_K(problem.M, problem.K);
|
||||
const auto b_scale_grid_desc_bn_ak = MakeBScaleGridDesciptor_N_K(problem.N, problem.K);
|
||||
const auto a_scale_grid_desc_am_ak =
|
||||
MakeAScaleGridDesciptor_M_K(problem.M, problem.K, problem.StrideScaleA);
|
||||
const auto b_scale_grid_desc_bn_ak =
|
||||
MakeBScaleGridDesciptor_N_K(problem.N, problem.K, problem.StrideScaleB);
|
||||
|
||||
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
|
||||
@@ -380,236 +380,143 @@ struct sequence_reduce<Reduce, Seq>
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename Values, typename Ids, typename Compare>
|
||||
struct sequence_sort_impl
|
||||
// Implement sequence_sort and sequence_unique_sort using constexpr functions (C++17)
|
||||
namespace sort_impl {
|
||||
|
||||
// Temporary arrays to hold values during operations with capacity N and mutable size.
|
||||
template <index_t N>
|
||||
struct IndexedValueArray
|
||||
{
|
||||
template <typename LeftValues,
|
||||
typename LeftIds,
|
||||
typename RightValues,
|
||||
typename RightIds,
|
||||
typename MergedValues,
|
||||
typename MergedIds,
|
||||
typename Comp>
|
||||
struct sorted_sequence_merge_impl
|
||||
{
|
||||
static constexpr bool choose_left = LeftValues::Front() < RightValues::Front();
|
||||
|
||||
static constexpr index_t chosen_value =
|
||||
choose_left ? LeftValues::Front() : RightValues::Front();
|
||||
static constexpr index_t chosen_id = choose_left ? LeftIds::Front() : RightIds::Front();
|
||||
|
||||
using new_merged_values = decltype(MergedValues::PushBack(Number<chosen_value>{}));
|
||||
using new_merged_ids = decltype(MergedIds::PushBack(Number<chosen_id>{}));
|
||||
|
||||
using new_left_values =
|
||||
typename conditional<choose_left, decltype(LeftValues::PopFront()), LeftValues>::type;
|
||||
using new_left_ids =
|
||||
typename conditional<choose_left, decltype(LeftIds::PopFront()), LeftIds>::type;
|
||||
|
||||
using new_right_values =
|
||||
typename conditional<choose_left, RightValues, decltype(RightValues::PopFront())>::type;
|
||||
using new_right_ids =
|
||||
typename conditional<choose_left, RightIds, decltype(RightIds::PopFront())>::type;
|
||||
|
||||
using merge = sorted_sequence_merge_impl<new_left_values,
|
||||
new_left_ids,
|
||||
new_right_values,
|
||||
new_right_ids,
|
||||
new_merged_values,
|
||||
new_merged_ids,
|
||||
Comp>;
|
||||
// this is output
|
||||
using merged_values = typename merge::merged_values;
|
||||
using merged_ids = typename merge::merged_ids;
|
||||
};
|
||||
|
||||
template <typename LeftValues,
|
||||
typename LeftIds,
|
||||
typename MergedValues,
|
||||
typename MergedIds,
|
||||
typename Comp>
|
||||
struct sorted_sequence_merge_impl<LeftValues,
|
||||
LeftIds,
|
||||
Sequence<>,
|
||||
Sequence<>,
|
||||
MergedValues,
|
||||
MergedIds,
|
||||
Comp>
|
||||
{
|
||||
using merged_values = typename sequence_merge<MergedValues, LeftValues>::type;
|
||||
using merged_ids = typename sequence_merge<MergedIds, LeftIds>::type;
|
||||
};
|
||||
|
||||
template <typename RightValues,
|
||||
typename RightIds,
|
||||
typename MergedValues,
|
||||
typename MergedIds,
|
||||
typename Comp>
|
||||
struct sorted_sequence_merge_impl<Sequence<>,
|
||||
Sequence<>,
|
||||
RightValues,
|
||||
RightIds,
|
||||
MergedValues,
|
||||
MergedIds,
|
||||
Comp>
|
||||
{
|
||||
using merged_values = typename sequence_merge<MergedValues, RightValues>::type;
|
||||
using merged_ids = typename sequence_merge<MergedIds, RightIds>::type;
|
||||
};
|
||||
|
||||
template <typename LeftValues,
|
||||
typename LeftIds,
|
||||
typename RightValues,
|
||||
typename RightIds,
|
||||
typename Comp>
|
||||
struct sorted_sequence_merge
|
||||
{
|
||||
using merge = sorted_sequence_merge_impl<LeftValues,
|
||||
LeftIds,
|
||||
RightValues,
|
||||
RightIds,
|
||||
Sequence<>,
|
||||
Sequence<>,
|
||||
Comp>;
|
||||
|
||||
using merged_values = typename merge::merged_values;
|
||||
using merged_ids = typename merge::merged_ids;
|
||||
};
|
||||
|
||||
static constexpr index_t nsize = Values::Size();
|
||||
|
||||
using split_unsorted_values = sequence_split<Values, nsize / 2>;
|
||||
using split_unsorted_ids = sequence_split<Ids, nsize / 2>;
|
||||
|
||||
using left_unsorted_values = typename split_unsorted_values::left_type;
|
||||
using left_unsorted_ids = typename split_unsorted_ids::left_type;
|
||||
using left_sort = sequence_sort_impl<left_unsorted_values, left_unsorted_ids, Compare>;
|
||||
using left_sorted_values = typename left_sort::sorted_values;
|
||||
using left_sorted_ids = typename left_sort::sorted_ids;
|
||||
|
||||
using right_unsorted_values = typename split_unsorted_values::right_type;
|
||||
using right_unsorted_ids = typename split_unsorted_ids::right_type;
|
||||
using right_sort = sequence_sort_impl<right_unsorted_values, right_unsorted_ids, Compare>;
|
||||
using right_sorted_values = typename right_sort::sorted_values;
|
||||
using right_sorted_ids = typename right_sort::sorted_ids;
|
||||
|
||||
using merged_sorted = sorted_sequence_merge<left_sorted_values,
|
||||
left_sorted_ids,
|
||||
right_sorted_values,
|
||||
right_sorted_ids,
|
||||
Compare>;
|
||||
|
||||
using sorted_values = typename merged_sorted::merged_values;
|
||||
using sorted_ids = typename merged_sorted::merged_ids;
|
||||
index_t values[N > 0 ? N : 1];
|
||||
index_t ids[N > 0 ? N : 1];
|
||||
index_t size = 0;
|
||||
};
|
||||
|
||||
template <index_t ValueX, index_t ValueY, index_t IdX, index_t IdY, typename Compare>
|
||||
struct sequence_sort_impl<Sequence<ValueX, ValueY>, Sequence<IdX, IdY>, Compare>
|
||||
template <index_t... Is>
|
||||
constexpr auto make_indexed_value_array(Sequence<Is...>)
|
||||
{
|
||||
static constexpr bool choose_x = Compare{}(ValueX, ValueY);
|
||||
constexpr index_t N = sizeof...(Is);
|
||||
IndexedValueArray<N> result = {{Is...}, {}, N};
|
||||
for(index_t i = 0; i < N; ++i)
|
||||
{
|
||||
result.ids[i] = i;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
using sorted_values =
|
||||
typename conditional<choose_x, Sequence<ValueX, ValueY>, Sequence<ValueY, ValueX>>::type;
|
||||
using sorted_ids = typename conditional<choose_x, Sequence<IdX, IdY>, Sequence<IdY, IdX>>::type;
|
||||
enum class SortField
|
||||
{
|
||||
Values,
|
||||
Ids
|
||||
};
|
||||
|
||||
template <index_t Value, index_t Id, typename Compare>
|
||||
struct sequence_sort_impl<Sequence<Value>, Sequence<Id>, Compare>
|
||||
// Perform an insertion sort on an IndexedValueArray.
|
||||
template <index_t N, typename Compare>
|
||||
constexpr auto insertion_sort(IndexedValueArray<N> arr, Compare comp)
|
||||
{
|
||||
using sorted_values = Sequence<Value>;
|
||||
using sorted_ids = Sequence<Id>;
|
||||
for(index_t i = 1; i < arr.size; ++i)
|
||||
{
|
||||
index_t key_val = arr.values[i];
|
||||
index_t key_id = arr.ids[i];
|
||||
index_t j = i - 1;
|
||||
while(j >= 0 && comp(key_val, arr.values[j]))
|
||||
{
|
||||
arr.values[j + 1] = arr.values[j];
|
||||
arr.ids[j + 1] = arr.ids[j];
|
||||
--j;
|
||||
}
|
||||
arr.values[j + 1] = key_val;
|
||||
arr.ids[j + 1] = key_id;
|
||||
}
|
||||
return arr;
|
||||
}
|
||||
|
||||
// Remove duplicates from a sorted IndexedValueArray.
|
||||
template <index_t N, typename Equal>
|
||||
constexpr auto unique(const IndexedValueArray<N>& sorted, Equal eq)
|
||||
{
|
||||
IndexedValueArray<N> result{};
|
||||
if constexpr(N == 0)
|
||||
{
|
||||
return result;
|
||||
}
|
||||
result.size = 1;
|
||||
result.values[0] = sorted.values[0];
|
||||
result.ids[0] = sorted.ids[0];
|
||||
for(index_t i = 1; i < sorted.size; ++i)
|
||||
{
|
||||
if(!eq(sorted.values[i], sorted.values[i - 1]))
|
||||
{
|
||||
result.values[result.size] = sorted.values[i];
|
||||
result.ids[result.size] = sorted.ids[i];
|
||||
++result.size;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Compute sorted (and optionally unique) IndexedValueArray from input Sequence.
|
||||
template <bool Unique, typename Compare, typename Equal, index_t... Is>
|
||||
constexpr auto compute_sorted(Sequence<Is...> seq, Compare comp, Equal eq)
|
||||
{
|
||||
auto sorted = insertion_sort(make_indexed_value_array(seq), comp);
|
||||
return Unique ? unique(sorted, eq) : sorted;
|
||||
}
|
||||
|
||||
// Cache the sorted results to avoid recomputation.
|
||||
template <bool Unique, typename Seq, typename Compare, typename Equal>
|
||||
struct SortedCache
|
||||
{
|
||||
static constexpr auto data = compute_sorted<Unique>(Seq{}, Compare{}, Equal{});
|
||||
};
|
||||
|
||||
template <typename Compare>
|
||||
struct sequence_sort_impl<Sequence<>, Sequence<>, Compare>
|
||||
// Build sorted value and ID sequences from cached sorted data
|
||||
template <SortField Field, bool Unique, typename Seq, typename Compare, typename Equal, index_t I>
|
||||
constexpr index_t get_sorted_field()
|
||||
{
|
||||
using sorted_values = Sequence<>;
|
||||
using sorted_ids = Sequence<>;
|
||||
constexpr auto& data = SortedCache<Unique, Seq, Compare, Equal>::data;
|
||||
return (Field == SortField::Values) ? data.values[I] : data.ids[I];
|
||||
}
|
||||
|
||||
template <bool Unique, typename Seq, typename Compare, typename Equal, typename IndexSeq>
|
||||
struct SortedSequences;
|
||||
|
||||
template <bool Unique, typename Seq, typename Compare, typename Equal, index_t... Is>
|
||||
struct SortedSequences<Unique, Seq, Compare, Equal, Sequence<Is...>>
|
||||
{
|
||||
using values_type =
|
||||
Sequence<get_sorted_field<SortField::Values, Unique, Seq, Compare, Equal, Is>()...>;
|
||||
using ids_type =
|
||||
Sequence<get_sorted_field<SortField::Ids, Unique, Seq, Compare, Equal, Is>()...>;
|
||||
};
|
||||
|
||||
template <bool Unique, typename Seq, typename Compare, typename Equal>
|
||||
using sorted_sequences_t = SortedSequences<
|
||||
Unique,
|
||||
Seq,
|
||||
Compare,
|
||||
Equal,
|
||||
typename arithmetic_sequence_gen<0, SortedCache<Unique, Seq, Compare, Equal>::data.size, 1>::
|
||||
type>;
|
||||
|
||||
using Equal = ck::math::equal<index_t>;
|
||||
|
||||
} // namespace sort_impl
|
||||
|
||||
template <typename Values, typename Compare>
|
||||
struct sequence_sort
|
||||
{
|
||||
using unsorted_ids = typename arithmetic_sequence_gen<0, Values::Size(), 1>::type;
|
||||
using sort = sequence_sort_impl<Values, unsorted_ids, Compare>;
|
||||
|
||||
// this is output
|
||||
using type = typename sort::sorted_values;
|
||||
using sorted2unsorted_map = typename sort::sorted_ids;
|
||||
using sorted_seqs = sort_impl::sorted_sequences_t<false, Values, Compare, sort_impl::Equal>;
|
||||
using type = typename sorted_seqs::values_type;
|
||||
using sorted2unsorted_map = typename sorted_seqs::ids_type;
|
||||
};
|
||||
|
||||
template <typename Values, typename Less, typename Equal>
|
||||
struct sequence_unique_sort
|
||||
{
|
||||
template <typename RemainValues,
|
||||
typename RemainIds,
|
||||
typename UniquifiedValues,
|
||||
typename UniquifiedIds,
|
||||
typename Eq>
|
||||
struct sorted_sequence_uniquify_impl
|
||||
{
|
||||
static constexpr index_t current_value = RemainValues::Front();
|
||||
static constexpr index_t current_id = RemainIds::Front();
|
||||
|
||||
static constexpr bool is_unique_value = (current_value != UniquifiedValues::Back());
|
||||
|
||||
using new_remain_values = decltype(RemainValues::PopFront());
|
||||
using new_remain_ids = decltype(RemainIds::PopFront());
|
||||
|
||||
using new_uniquified_values =
|
||||
typename conditional<is_unique_value,
|
||||
decltype(UniquifiedValues::PushBack(Number<current_value>{})),
|
||||
UniquifiedValues>::type;
|
||||
|
||||
using new_uniquified_ids =
|
||||
typename conditional<is_unique_value,
|
||||
decltype(UniquifiedIds::PushBack(Number<current_id>{})),
|
||||
UniquifiedIds>::type;
|
||||
|
||||
using uniquify = sorted_sequence_uniquify_impl<new_remain_values,
|
||||
new_remain_ids,
|
||||
new_uniquified_values,
|
||||
new_uniquified_ids,
|
||||
Eq>;
|
||||
|
||||
// this is output
|
||||
using uniquified_values = typename uniquify::uniquified_values;
|
||||
using uniquified_ids = typename uniquify::uniquified_ids;
|
||||
};
|
||||
|
||||
template <typename UniquifiedValues, typename UniquifiedIds, typename Eq>
|
||||
struct sorted_sequence_uniquify_impl<Sequence<>,
|
||||
Sequence<>,
|
||||
UniquifiedValues,
|
||||
UniquifiedIds,
|
||||
Eq>
|
||||
{
|
||||
using uniquified_values = UniquifiedValues;
|
||||
using uniquified_ids = UniquifiedIds;
|
||||
};
|
||||
|
||||
template <typename SortedValues, typename SortedIds, typename Eq>
|
||||
struct sorted_sequence_uniquify
|
||||
{
|
||||
using uniquify = sorted_sequence_uniquify_impl<decltype(SortedValues::PopFront()),
|
||||
decltype(SortedIds::PopFront()),
|
||||
Sequence<SortedValues::Front()>,
|
||||
Sequence<SortedIds::Front()>,
|
||||
Eq>;
|
||||
|
||||
using uniquified_values = typename uniquify::uniquified_values;
|
||||
using uniquified_ids = typename uniquify::uniquified_ids;
|
||||
};
|
||||
|
||||
using sort = sequence_sort<Values, Less>;
|
||||
using sorted_values = typename sort::type;
|
||||
using sorted_ids = typename sort::sorted2unsorted_map;
|
||||
|
||||
using uniquify = sorted_sequence_uniquify<sorted_values, sorted_ids, Equal>;
|
||||
|
||||
// this is output
|
||||
using type = typename uniquify::uniquified_values;
|
||||
using sorted2unsorted_map = typename uniquify::uniquified_ids;
|
||||
using sorted_seqs = sort_impl::sorted_sequences_t<true, Values, Less, Equal>;
|
||||
using type = typename sorted_seqs::values_type;
|
||||
using sorted2unsorted_map = typename sorted_seqs::ids_type;
|
||||
};
|
||||
|
||||
template <typename SeqMap>
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/algorithm/cluster_descriptor.hpp"
|
||||
|
||||
@@ -1550,9 +1550,10 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
|
||||
(std::is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, e8m0_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, pk_int4_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32) ||
|
||||
(std::is_same<T, pk_fp4_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16))),
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) ||
|
||||
(std::is_same<T, pk_fp4_raw_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, pk_fp4_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
|
||||
"wrong! not implemented");
|
||||
|
||||
using rtn_type = thread_buffer<T, N>;
|
||||
|
||||
@@ -39,8 +39,12 @@
|
||||
#define CK_TILE_DEVICE inline __device__
|
||||
#define CK_TILE_HOST_DEVICE inline __host__ __device__
|
||||
#define CK_TILE_DEVICE_EXTERN __device__
|
||||
#if __clang_major__ < 22
|
||||
#define CK_TILE_HOST_DEVICE_EXTERN __host__ __device__
|
||||
#else
|
||||
#define CK_TILE_HOST_DEVICE_EXTERN
|
||||
#endif
|
||||
#else
|
||||
#define CK_TILE_HOST inline
|
||||
#define CK_TILE_DEVICE inline
|
||||
#define CK_TILE_HOST_DEVICE inline
|
||||
|
||||
@@ -41,9 +41,8 @@ struct scales
|
||||
Scale lhs_;
|
||||
};
|
||||
|
||||
/// FIXME: create macro to replace '__host__ __device__' and nothing more
|
||||
template <typename Scale>
|
||||
__host__ __device__ scales(Scale) -> scales<Scale>;
|
||||
CK_TILE_HOST_DEVICE_EXTERN scales(Scale) -> scales<Scale>;
|
||||
|
||||
template <typename Left = void, typename Right = Left>
|
||||
struct plus
|
||||
@@ -66,8 +65,7 @@ struct plus<void, void>
|
||||
}
|
||||
};
|
||||
|
||||
/// FIXME: create macro to replace '__host__ __device__' and nothing more
|
||||
__host__ __device__ plus() -> plus<void, void>;
|
||||
CK_TILE_HOST_DEVICE_EXTERN plus() -> plus<void, void>;
|
||||
|
||||
template <typename Left = void, typename Right = Left>
|
||||
struct minus
|
||||
@@ -90,8 +88,7 @@ struct minus<void, void>
|
||||
}
|
||||
};
|
||||
|
||||
/// FIXME: create macro to replace '__host__ __device__' and nothing more
|
||||
__host__ __device__ minus() -> minus<void, void>;
|
||||
CK_TILE_HOST_DEVICE_EXTERN minus() -> minus<void, void>;
|
||||
|
||||
template <typename Left = void, typename Right = Left>
|
||||
struct multiplies
|
||||
@@ -114,8 +111,7 @@ struct multiplies<void, void>
|
||||
}
|
||||
};
|
||||
|
||||
/// FIXME: create macro to replace '__host__ __device__' and nothing more
|
||||
__host__ __device__ multiplies() -> multiplies<void, void>;
|
||||
CK_TILE_HOST_DEVICE_EXTERN multiplies() -> multiplies<void, void>;
|
||||
|
||||
template <typename T>
|
||||
struct maximize
|
||||
@@ -345,8 +341,7 @@ struct equal<void, void>
|
||||
}
|
||||
};
|
||||
|
||||
/// FIXME: create macro to replace '__host__ __device__' and nothing more
|
||||
__host__ __device__ equal() -> equal<void, void>;
|
||||
CK_TILE_HOST_DEVICE_EXTERN equal() -> equal<void, void>;
|
||||
|
||||
template <>
|
||||
struct equal<float, float>
|
||||
@@ -387,8 +382,7 @@ struct less<void, void>
|
||||
}
|
||||
};
|
||||
|
||||
/// FIXME: create macro to replace '__host__ __device__' and nothing more
|
||||
__host__ __device__ less() -> less<void, void>;
|
||||
CK_TILE_HOST_DEVICE_EXTERN less() -> less<void, void>;
|
||||
|
||||
template <typename Left = void, typename Right = Left>
|
||||
struct less_equal
|
||||
@@ -411,8 +405,7 @@ struct less_equal<void, void>
|
||||
}
|
||||
};
|
||||
|
||||
/// FIXME: create macro to replace '__host__ __device__' and nothing more
|
||||
__host__ __device__ less_equal() -> less_equal<void, void>;
|
||||
CK_TILE_HOST_DEVICE_EXTERN less_equal() -> less_equal<void, void>;
|
||||
|
||||
template <>
|
||||
struct less_equal<float, float>
|
||||
|
||||
@@ -47,9 +47,8 @@ struct composes<F>
|
||||
F f_;
|
||||
};
|
||||
|
||||
/// FIXME: create macro to replace '__host__ __device__' and nothing more
|
||||
template <typename... Ts>
|
||||
__host__ __device__ composes(Ts&&...) -> composes<remove_cvref_t<Ts>...>;
|
||||
CK_TILE_HOST_DEVICE_EXTERN composes(Ts&&...) -> composes<remove_cvref_t<Ts>...>;
|
||||
|
||||
template <typename SaturateType>
|
||||
struct saturates
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/host/arg_parser.hpp"
|
||||
|
||||
@@ -52,9 +52,19 @@ template <typename ComputeDataType, typename OutDataType, typename AccDataType =
|
||||
CK_TILE_HOST double get_relative_threshold(const int number_of_accumulations = 1)
|
||||
{
|
||||
|
||||
static_assert(
|
||||
is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
|
||||
"Warning: Unhandled ComputeDataType for setting up the relative threshold!");
|
||||
static_assert(is_any_of<ComputeDataType,
|
||||
F8,
|
||||
BF8,
|
||||
F16,
|
||||
BF16,
|
||||
F32,
|
||||
pk_fp4_t,
|
||||
pk_fp4_raw_t,
|
||||
pk_int4_t,
|
||||
I8,
|
||||
I32,
|
||||
int>::value,
|
||||
"Warning: Unhandled ComputeDataType for setting up the relative threshold!");
|
||||
|
||||
double compute_error = 0;
|
||||
if constexpr(is_any_of<ComputeDataType, pk_int4_t, I8, I32, int>::value)
|
||||
@@ -113,9 +123,19 @@ CK_TILE_HOST double get_absolute_threshold(const double max_possible_num,
|
||||
const int number_of_accumulations = 1)
|
||||
{
|
||||
|
||||
static_assert(
|
||||
is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
|
||||
"Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
|
||||
static_assert(is_any_of<ComputeDataType,
|
||||
F8,
|
||||
BF8,
|
||||
F16,
|
||||
BF16,
|
||||
F32,
|
||||
pk_fp4_t,
|
||||
pk_fp4_raw_t,
|
||||
pk_int4_t,
|
||||
I8,
|
||||
I32,
|
||||
int>::value,
|
||||
"Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
|
||||
|
||||
auto expo = std::log2(std::abs(max_possible_num));
|
||||
double compute_error = 0;
|
||||
|
||||
@@ -372,6 +372,63 @@ CK_TILE_HOST void reference_gemm_tensor_quant(const HostTensor<ADataType>& a_m_k
|
||||
make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename QDataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename QuantGroupSize,
|
||||
bool aquant,
|
||||
typename AElementOp = ck_tile::identity,
|
||||
typename BElementOp = ck_tile::identity,
|
||||
typename ACCElementOp = ck_tile::identity>
|
||||
CK_TILE_HOST void reference_mxfp4gemm_quant(const HostTensor<ADataType>& a_m_k,
|
||||
const HostTensor<QDataType>& q,
|
||||
const HostTensor<BDataType>& b_k_n,
|
||||
HostTensor<CDataType>& c_m_n,
|
||||
const AElementOp& a_element_op = {},
|
||||
const BElementOp& b_element_op = {},
|
||||
const ACCElementOp& acc_element_op = {})
|
||||
{
|
||||
const std::size_t M = a_m_k.get_length(0);
|
||||
const std::size_t N = b_k_n.get_length(1);
|
||||
const std::size_t K = a_m_k.get_length(1);
|
||||
|
||||
auto f_mn = [&](auto m, auto n) {
|
||||
AccDataType v_acc = 0;
|
||||
AccDataType pasual = 0;
|
||||
for(std::size_t k = 0; k < (K / 2); k++)
|
||||
{
|
||||
using ComputeType = float;
|
||||
auto b_scale = type_convert<int32_t>(q((2 * k) / QuantGroupSize::kK, n)) - 127;
|
||||
ComputeType v_a_0, v_a_1;
|
||||
ComputeType v_b_0, v_b_1;
|
||||
|
||||
v_a_0 = ck_tile::type_convert<ComputeType>((a_element_op(a_m_k(m, 2 * k))));
|
||||
v_a_1 = ck_tile::type_convert<ComputeType>((a_element_op(a_m_k(m, 2 * k + 1))));
|
||||
|
||||
if constexpr(std::is_same_v<BDataType, pk_fp4_raw_t>)
|
||||
{
|
||||
auto b_pack = type_convert<pk_fp4_t>(b_element_op(b_k_n(k, n)));
|
||||
auto b_scale_fp4 = type_convert<float>(std::pow(2.0f, b_scale));
|
||||
|
||||
auto b_f4_lo = type_convert<pk_fp4_t>(b_pack.unpack(number<0>{}));
|
||||
auto b_f4_hi = type_convert<pk_fp4_t>(b_pack.unpack(number<1>{}));
|
||||
|
||||
v_b_0 = type_convert<ComputeType>(b_f4_lo) * b_scale_fp4;
|
||||
v_b_1 = type_convert<ComputeType>(b_f4_hi) * b_scale_fp4;
|
||||
}
|
||||
|
||||
pasual = v_a_0 * v_b_0 + v_a_1 * v_b_1;
|
||||
v_acc += pasual;
|
||||
}
|
||||
c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp"
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp"
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp"
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
|
||||
@@ -22,6 +22,7 @@ template <> struct DataTypeTraits<bf8_t> { static constexpr const char * name =
|
||||
template <> struct DataTypeTraits<int8_t> { static constexpr const char * name = "int8"; };
|
||||
template <> struct DataTypeTraits<pk_int4_t> { static constexpr const char * name = "pk_int4"; };
|
||||
template <> struct DataTypeTraits<pk_fp4_t> { static constexpr const char * name = "pk_fp4"; };
|
||||
template <> struct DataTypeTraits<pk_fp4_raw_t> { static constexpr const char * name = "pk_fp4_raw"; };
|
||||
|
||||
template <memory_operation_enum MemOp> struct memOpToStr;
|
||||
template <> struct memOpToStr<memory_operation_enum::set> { static constexpr const char * name = "set"; };
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/elementwise/binary_elementwise_operation.hpp"
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
|
||||
|
||||
@@ -92,11 +92,17 @@ struct CShuffleEpilogue
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataTypeTuple>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataTypeTuple>>;
|
||||
|
||||
using ATypeToUse =
|
||||
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
|
||||
using ATypeToUse = std::conditional_t<std::is_same_v<ADataType, pk_int4_t> ||
|
||||
std::is_same_v<ADataType, pk_fp4_t>,
|
||||
BDataType,
|
||||
ADataType>;
|
||||
// Used for weight-only quantization kernel, B would be dequantized to the same data type as A
|
||||
using BTypeToUse =
|
||||
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
|
||||
using BTypeToUse = std::conditional_t<std::is_same_v<BDataType, pk_int4_t> ||
|
||||
std::is_same_v<BDataType, pk_fp4_t> ||
|
||||
std::is_same_v<BDataType, pk_fp4_raw_t>,
|
||||
ADataType,
|
||||
BDataType>;
|
||||
|
||||
using ELayout = remove_cvref_t<typename Problem::ELayout>;
|
||||
using CDElementwise = remove_cvref_t<typename Problem::CDElementwise>;
|
||||
static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation;
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp"
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp"
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp"
|
||||
|
||||
@@ -96,8 +96,10 @@ struct BlockUniversalGemmAsBsCr
|
||||
|
||||
using ATypeToUse =
|
||||
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
|
||||
using BTypeToUse =
|
||||
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
|
||||
using BTypeToUse = std::conditional_t<std::is_same_v<BDataType, pk_int4_t> ||
|
||||
std::is_same_v<BDataType, pk_fp4_raw_t>,
|
||||
ADataType,
|
||||
BDataType>;
|
||||
|
||||
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
|
||||
|
||||
|
||||
@@ -17,10 +17,12 @@ struct GemmPipelineAgBgCrImplBase
|
||||
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<number<0>{}, AsLayout>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<number<0>{}, BsLayout>>;
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<number<0>{}, AsLayout>>;
|
||||
using BInDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
|
||||
using BDataType =
|
||||
std::conditional_t<std::is_same_v<BInDataType, pk_fp4_raw_t>, ADataType, BInDataType>;
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<number<0>{}, BsLayout>>;
|
||||
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
@@ -270,12 +272,17 @@ struct GemmPipelineAgBgCrImplBase
|
||||
}();
|
||||
auto b_copy_lds_window = make_tile_window(b_lds_block_view, b_lds_shape, {0, 0});
|
||||
|
||||
using BLdsDataType =
|
||||
std::conditional_t<std::is_same_v<typename Problem::BDataType, pk_fp4_raw_t>,
|
||||
typename Problem::ADataType,
|
||||
typename Problem::BDataType>;
|
||||
|
||||
auto b_lds_load_tile_distr = []() {
|
||||
if constexpr(is_b_load_tr)
|
||||
return make_static_tile_distribution(
|
||||
typename InputTileDistributionTraits<
|
||||
typename BLdsLoadTileDistr::DstrEncode,
|
||||
typename Problem::BDataType>::TransposedDstrEncode{});
|
||||
typename InputTileDistributionTraits<typename BLdsLoadTileDistr::DstrEncode,
|
||||
BLdsDataType>::TransposedDstrEncode{});
|
||||
|
||||
else
|
||||
return BLdsLoadTileDistr{};
|
||||
}();
|
||||
|
||||
@@ -303,8 +303,11 @@ struct UniversalGemmBasePolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
|
||||
{
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using BDataType =
|
||||
std::conditional_t<std::is_same_v<typename Problem::BDataType, pk_fp4_raw_t>,
|
||||
typename Problem::ADataType,
|
||||
typename Problem::BDataType>;
|
||||
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
@@ -585,9 +588,12 @@ struct UniversalGemmBasePolicy
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<number<0>{}, BsLayout>>;
|
||||
using BInDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
|
||||
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<number<0>{}, BsLayout>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
|
||||
using BDataType = std::conditional_t<std::is_same_v<BInDataType, pk_fp4_raw_t>,
|
||||
typename Problem::ADataType,
|
||||
typename Problem::BDataType>;
|
||||
|
||||
if constexpr(Problem::FixedVectorSize)
|
||||
{
|
||||
@@ -729,13 +735,17 @@ struct UniversalGemmBasePolicy
|
||||
{
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
constexpr index_t KPerBlock = std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>
|
||||
? Problem::BlockGemmShape::kK / 2
|
||||
: Problem::BlockGemmShape::kK;
|
||||
constexpr index_t VecLoadSize =
|
||||
Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB<Problem>();
|
||||
std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>
|
||||
? 4
|
||||
: (Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB<Problem>());
|
||||
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
using BLayout = remove_cvref_t<
|
||||
std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::BsLayoutTuple>>>;
|
||||
using BLayout = remove_cvref_t<
|
||||
std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::BsLayoutTuple>>>;
|
||||
// Tile: KPerBlock X NPerBlock
|
||||
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
@@ -841,10 +851,12 @@ struct UniversalGemmBasePolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr index_t GetSmemSizeB()
|
||||
{
|
||||
constexpr index_t smem_size_b =
|
||||
integer_least_multiple(sizeof(typename Problem::BDataType) *
|
||||
Problem::BlockGemmShape::kN * Problem::BlockGemmShape::kK,
|
||||
16);
|
||||
using BDataType =
|
||||
std::conditional_t<std::is_same_v<typename Problem::BDataType, pk_fp4_raw_t>,
|
||||
typename Problem::ADataType,
|
||||
typename Problem::BDataType>;
|
||||
constexpr index_t smem_size_b = integer_least_multiple(
|
||||
sizeof(BDataType) * Problem::BlockGemmShape::kN * Problem::BlockGemmShape::kK, 16);
|
||||
return smem_size_b;
|
||||
}
|
||||
|
||||
@@ -882,8 +894,10 @@ struct UniversalGemmPipelineAgBgCrPolicy
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using ATypeToUse =
|
||||
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
|
||||
using BTypeToUse =
|
||||
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
|
||||
using BTypeToUse = std::conditional_t<std::is_same_v<BDataType, pk_int4_t> ||
|
||||
std::is_same_v<BDataType, pk_fp4_raw_t>,
|
||||
ADataType,
|
||||
BDataType>;
|
||||
|
||||
using WarpGemm = WarpGemmDispatcher<ATypeToUse,
|
||||
BTypeToUse,
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp"
|
||||
@@ -21,6 +20,9 @@
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp"
|
||||
|
||||
@@ -759,12 +759,20 @@ struct QuantGemmKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_ptr,
|
||||
make_tuple(kargs.N, splitk_batch_offset.splitted_k),
|
||||
make_tuple(kargs.stride_B, 1),
|
||||
number<GemmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
if constexpr(std::is_same_v<BDataType, pk_fp4_raw_t>)
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_ptr,
|
||||
make_tuple(kargs.N, splitk_batch_offset.splitted_k / 2),
|
||||
make_tuple(kargs.stride_B, 1),
|
||||
number<GemmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
else
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_ptr,
|
||||
make_tuple(kargs.N, splitk_batch_offset.splitted_k),
|
||||
make_tuple(kargs.stride_B, 1),
|
||||
number<GemmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -900,10 +908,16 @@ struct QuantGemmKernel
|
||||
const auto& b_tensor_view = views.at(I2);
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return pad_tensor_view(b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadK>{});
|
||||
if constexpr(std::is_same_v<BDataType, pk_fp4_raw_t>)
|
||||
return pad_tensor_view(b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock / 2>{}),
|
||||
sequence<false, GemmPipeline::kPadK>{});
|
||||
else
|
||||
return pad_tensor_view(b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadK>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1046,10 +1060,17 @@ struct QuantGemmKernel
|
||||
{
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return make_tile_window(b_pad_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{i_n, 0});
|
||||
if constexpr(std::is_same_v<BDataType, pk_fp4_raw_t>)
|
||||
return make_tile_window(
|
||||
b_pad_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock / 2>{}),
|
||||
{i_n, 0});
|
||||
else
|
||||
return make_tile_window(b_pad_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{i_n, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem, typename Policy>
|
||||
struct GemmMxFp4PipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Problem, Policy>
|
||||
{
|
||||
using Base = GemmPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
using ADataType = typename Base::ADataType;
|
||||
using ALayout = typename Base::ALayout;
|
||||
using BDataType = typename Base::BDataType;
|
||||
using BLayout = typename Base::BLayout;
|
||||
using BlockGemmShape = typename Base::BlockGemmShape;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
|
||||
|
||||
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
|
||||
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t NPerBlockBQ = NPerBlock / QuantGroupSize::kN;
|
||||
static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK;
|
||||
|
||||
static_assert(NPerBlockBQ >= 1, "NPerBlock must be >= QuantGroupSize");
|
||||
static_assert(KPerBlockBQ >= 1, "KPerBlock must be >= QuantGroupSize");
|
||||
|
||||
static_assert(NPerBlock % QuantGroupSize::kN == 0,
|
||||
"NPerBlock must be a multiple of QuantGroupSize::kN");
|
||||
static_assert(KPerBlock % QuantGroupSize::kK == 0,
|
||||
"KPerBlock must be a multiple of QuantGroupSize::kK");
|
||||
|
||||
// Create DRAM tile window for BQ
|
||||
template <typename BQDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
GetBQDramLoadWindow(const BQDramBlockWindowTmp& bq_dram_block_window_tmp) const
|
||||
{
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
|
||||
using YPerTile = number<NPerBlockBQ>;
|
||||
using XPerTile = number<KPerBlockBQ>;
|
||||
|
||||
auto bq_copy_dram_window =
|
||||
make_tile_window(bq_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(YPerTile(), XPerTile()),
|
||||
bq_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeBQDramTileDistribution<Problem>());
|
||||
return bq_copy_dram_window;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,140 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "gemm_group_quant_utils.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct GemmMxFp4PipelineAgBgCrPolicy : public UniversalGemmPipelineAgBgCrPolicy
|
||||
{
|
||||
using Base = UniversalGemmPipelineAgBgCrPolicy;
|
||||
using Base::I0;
|
||||
using Base::I1;
|
||||
using Base::I2;
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeBQ()
|
||||
{
|
||||
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
|
||||
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t NPerBlockBQ = NPerBlock / Problem::QuantGroupSize::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPerBlockBQ = KPerBlock / Problem::QuantGroupSize::kK;
|
||||
|
||||
static_assert(std::is_same_v<BQLayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
|
||||
return GetABQGlobalVectorLoadSize<Problem, BQDataType, NPerBlockBQ, KPerBlockBQ>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBRegTileDistribution()
|
||||
{
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t VecLoadSize =
|
||||
Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB<Problem>();
|
||||
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
// Tile: KPerBlock X NPerBlock
|
||||
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
using TileEncodingPattern =
|
||||
tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
KPerBlock,
|
||||
NPerBlock,
|
||||
VecLoadSize,
|
||||
getBTileAccessPattern(),
|
||||
NumWaveGroups>;
|
||||
return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
}
|
||||
// Tile: NPerBlock X KPerBlock
|
||||
else
|
||||
{
|
||||
using TileEncodingPattern =
|
||||
tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
VecLoadSize,
|
||||
getBTileAccessPattern(),
|
||||
NumWaveGroups>;
|
||||
return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBQDramTileDistribution()
|
||||
{
|
||||
// using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t KScale = KPerBlock / Problem::QuantGroupSize::kK; // k_scale num //2
|
||||
constexpr index_t VecLoadSize =
|
||||
Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB<Problem>();
|
||||
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
constexpr index_t warp_size = get_warp_size();
|
||||
constexpr index_t num_warps = BlockSize / get_warp_size();
|
||||
constexpr index_t LargestVec = (KPerBlock * NPerBlock) / (num_warps * warp_size);
|
||||
constexpr index_t b_vec = VecLoadSize > LargestVec ? LargestVec : VecLoadSize;
|
||||
constexpr index_t K0 = KPerBlock / b_vec;
|
||||
constexpr index_t K1 = K0 / KScale;
|
||||
constexpr index_t K3 = K0 / K1;
|
||||
constexpr index_t K2 = 1;
|
||||
|
||||
constexpr index_t N0 = num_warps / NumWaveGroups;
|
||||
constexpr index_t N1 = warp_size / K0;
|
||||
constexpr index_t N2 = NPerBlock / (N0 * N1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<K1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K3, K2>>,
|
||||
tuple<sequence<1>, sequence<1, 2, 0>>,
|
||||
tuple<sequence<0>, sequence<1, 0, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
|
||||
{
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
|
||||
static_assert(Problem::QuantGroupSize::kK % WarpTile::at(I2) == 0,
|
||||
"KPerWarpGemm must be a multiple of QuantGroupSize!");
|
||||
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
static_assert(std::is_same_v<typename Problem::ComputeDataType, fp8_t> ||
|
||||
std::is_same_v<typename Problem::ComputeDataType, bf8_t> ||
|
||||
std::is_same_v<typename Problem::ComputeDataType, bf16_t>);
|
||||
static_assert(std::is_same_v<typename Problem::CDataType, float>);
|
||||
|
||||
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<
|
||||
typename Problem::ADataType,
|
||||
std::conditional_t<std::is_same_v<typename Problem::BDataType, pk_fp4_raw_t>,
|
||||
typename Problem::ADataType,
|
||||
typename Problem::BDataType>,
|
||||
typename Problem::CDataType,
|
||||
BlockWarps,
|
||||
WarpGemm>;
|
||||
|
||||
return BlockUniversalGemmAsBsCr<Problem, BlockGemmPolicy>{};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,665 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A Tile Window: global memory
|
||||
// B Tile Window: global memory
|
||||
// C Distributed tensor: register
|
||||
|
||||
template <typename Problem, typename Policy = GemmMxFp4PipelineAgBgCrPolicy>
|
||||
struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
{
|
||||
using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
|
||||
using PipelineImplBase = GemmMxFp4PipelineAgBgCrImplBase<Problem, Policy>;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using BDqDataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
|
||||
|
||||
static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!");
|
||||
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
using I2 = number<2>;
|
||||
|
||||
static constexpr index_t APackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
static constexpr index_t BPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BDqDataType>>::PackedSize;
|
||||
|
||||
static constexpr index_t BQPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BQDataType>>::PackedSize;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
|
||||
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t NPerBlockBQ = BlockGemmShape::kN / QuantGroupSize::kN;
|
||||
static constexpr index_t KPerBlockBQ = BlockGemmShape::kK / QuantGroupSize::kK;
|
||||
|
||||
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
|
||||
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
|
||||
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
|
||||
static constexpr index_t GetVectorSizeBQ()
|
||||
{
|
||||
return Policy::template GetVectorSizeBQ<Problem>();
|
||||
}
|
||||
|
||||
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
|
||||
static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
|
||||
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
|
||||
static constexpr bool HasHotLoop = Problem::HasHotLoop;
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
|
||||
using Base::PrefetchStages;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
|
||||
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
|
||||
return concat('_', "mxfp4gemm_pipeline_AgBgCrCompV3",
|
||||
concat('x', MPerBlock, NPerBlock, KPerBlock), BlockSize,
|
||||
concat('x', WaveNumM, WaveNumN),
|
||||
concat('x', kPadM, kPadN, kPadK),
|
||||
concat('x', kPadM, kPadN, kPadK), QuantGroupSize::GetName());
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
CK_TILE_HOST static std::string Print()
|
||||
{
|
||||
constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM;
|
||||
constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN;
|
||||
constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK;
|
||||
|
||||
constexpr index_t WaveSize = 64;
|
||||
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
|
||||
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
|
||||
|
||||
constexpr index_t A_LDS_Read_Width = GetSmemPackA();
|
||||
constexpr index_t B_LDS_Read_Width = GetSmemPackB();
|
||||
|
||||
constexpr index_t A_LDS_Write_Width = GetSmemPackA();
|
||||
constexpr index_t B_LDS_Write_Width = GetSmemPackB();
|
||||
|
||||
constexpr index_t A_Buffer_Load_Inst_Num =
|
||||
MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
|
||||
constexpr index_t B_Buffer_Load_Inst_Num =
|
||||
NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());
|
||||
constexpr index_t BQ_Buffer_Load_Inst_Num =
|
||||
NPerBlock * KPerBlockBQ / (BlockSize * GetVectorSizeBQ());
|
||||
|
||||
constexpr index_t A_LDS_Write_Inst_Num =
|
||||
MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width);
|
||||
constexpr index_t B_LDS_Write_Inst_Num =
|
||||
NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width);
|
||||
|
||||
constexpr index_t A_LDS_Read_Inst_Num =
|
||||
WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width);
|
||||
constexpr index_t B_LDS_Read_Inst_Num =
|
||||
WaveNumM * NPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);
|
||||
|
||||
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
|
||||
(BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
|
||||
|
||||
auto str = std::stringstream{};
|
||||
|
||||
str << "A/B vector size: " << GetVectorSizeA() << ", " << GetVectorSizeB() << ", "
|
||||
<< "BQ vector size: " << GetVectorSizeBQ() << "\n"
|
||||
<< "A/B LDS read/write width: " << A_LDS_Read_Width << ", " << B_LDS_Read_Width << "\n"
|
||||
<< "A/B buffer load inst: " << A_Buffer_Load_Inst_Num << ", " << B_Buffer_Load_Inst_Num
|
||||
<< ", " << "BQ buffer load inst: " << BQ_Buffer_Load_Inst_Num << "\n"
|
||||
<< "A/B LDS write inst: " << A_LDS_Write_Inst_Num << ", " << B_LDS_Write_Inst_Num
|
||||
<< "\n"
|
||||
<< "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
|
||||
<< "C MFMA inst: " << C_MFMA_Inst_Num << "\n"
|
||||
<< "QuantGroupSize: " << QuantGroupSize::GetName() << "\n"
|
||||
<< "KPack: " << BlockGemm::Traits::KPack << "\n"
|
||||
<< "PrefetchStages: " << PrefetchStages << "\n";
|
||||
return str.str();
|
||||
}
|
||||
|
||||
template <GemmPipelineScheduler Scheduler>
|
||||
struct PipelineImpl : public PipelineImplBase
|
||||
{
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineImpl<GemmPipelineScheduler::Intrawave> : public PipelineImplBase
|
||||
{
|
||||
using Base = PipelineImplBase;
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
|
||||
{
|
||||
constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM;
|
||||
constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN;
|
||||
constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK;
|
||||
|
||||
constexpr index_t WaveSize = 64;
|
||||
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
|
||||
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
|
||||
|
||||
// Below should be equal to AK1|BK1
|
||||
constexpr index_t A_LDS_Read_Width = GetSmemPackA();
|
||||
constexpr index_t B_LDS_Read_Width = GetSmemPackB();
|
||||
|
||||
constexpr index_t A_LDS_Write_Width = GetSmemPackA();
|
||||
constexpr index_t B_LDS_Write_Width = GetSmemPackB();
|
||||
|
||||
constexpr index_t A_Buffer_Load_Inst_Num =
|
||||
MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
|
||||
constexpr index_t B_Buffer_Load_Inst_Num =
|
||||
NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());
|
||||
|
||||
constexpr index_t A_LDS_Write_Inst_Num =
|
||||
MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width);
|
||||
constexpr index_t B_LDS_Write_Inst_Num =
|
||||
NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width);
|
||||
|
||||
constexpr index_t A_LDS_Read_Inst_Num =
|
||||
WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width);
|
||||
constexpr index_t B_LDS_Read_Inst_Num =
|
||||
WaveNumM * NPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);
|
||||
|
||||
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
|
||||
(BlockSize / WaveSize) /
|
||||
(MPerXDL * NPerXDL * KPerXDL);
|
||||
|
||||
// A/B split schedule
|
||||
// compiler is likely to use ds_read2 when instruction width smaller than 16bytes
|
||||
constexpr auto num_ds_read_inst_a =
|
||||
A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? A_LDS_Read_Inst_Num
|
||||
: A_LDS_Read_Inst_Num / 2;
|
||||
constexpr auto num_ds_read_inst_b =
|
||||
B_LDS_Read_Width * sizeof(BDqDataType) / BPackedSize == 16
|
||||
? B_LDS_Read_Inst_Num
|
||||
: B_LDS_Read_Inst_Num / 2;
|
||||
|
||||
constexpr auto num_ds_write_inst_a = A_LDS_Write_Inst_Num;
|
||||
constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num;
|
||||
|
||||
constexpr auto num_buffer_load_inst_a = A_Buffer_Load_Inst_Num;
|
||||
constexpr auto num_buffer_load_inst_b = B_Buffer_Load_Inst_Num;
|
||||
|
||||
constexpr auto num_mfma_inst = C_MFMA_Inst_Num;
|
||||
|
||||
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
|
||||
constexpr auto ds_read_a_issue_cycle =
|
||||
A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? 8 : 4;
|
||||
constexpr auto ds_read_b_issue_cycle =
|
||||
B_LDS_Read_Width * sizeof(BDqDataType) / BPackedSize == 16 ? 8 : 4;
|
||||
constexpr auto ds_read_a_mfma_rate =
|
||||
(mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
|
||||
constexpr auto ds_read_b_mfma_rate =
|
||||
(mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
|
||||
|
||||
constexpr auto num_dsread_a_mfma =
|
||||
(num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
|
||||
constexpr auto num_dsread_b_mfma =
|
||||
(num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
|
||||
|
||||
// stage 1
|
||||
// Separate this part?
|
||||
// constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) >
|
||||
// sizeof(ComputeDataType) /
|
||||
// sizeof(BDataType)
|
||||
// ? sizeof(ComputeDataType) /
|
||||
// sizeof(ADataType) : sizeof(ComputeDataType)
|
||||
// / sizeof(BDataType);
|
||||
constexpr auto num_mfma_stage1 =
|
||||
num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
|
||||
constexpr auto num_mfma_per_issue =
|
||||
num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
|
||||
constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
|
||||
constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
|
||||
|
||||
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
|
||||
ignore = idswrite;
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA
|
||||
});
|
||||
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
|
||||
ignore = idswrite;
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA
|
||||
});
|
||||
|
||||
// stage 2
|
||||
static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) {
|
||||
if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
|
||||
ds_read_a_mfma_rate)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
|
||||
}
|
||||
else
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x100,
|
||||
num_ds_read_inst_a - (num_dsread_a_mfma - 1) * ds_read_a_mfma_rate,
|
||||
0); // DS read
|
||||
}
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
|
||||
static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) {
|
||||
if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
|
||||
ds_read_b_mfma_rate)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
|
||||
}
|
||||
else
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x100,
|
||||
num_ds_read_inst_b - (num_dsread_b_mfma - 1) * ds_read_b_mfma_rate,
|
||||
0); // DS read
|
||||
}
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
}
|
||||
|
||||
template <bool HasHotLoop,
|
||||
TailNumber TailNum,
|
||||
typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename BQDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType,
|
||||
remove_cvref_t<typename BDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BQDataType,
|
||||
remove_cvref_t<typename BQDramBlockWindowTmp::DataType>>,
|
||||
"A/B/BQ Dram block window should have the same data type as appropriate "
|
||||
"([A|B|BQ]DataType) defined in Problem definition!");
|
||||
|
||||
constexpr bool is_a_col_major =
|
||||
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
|
||||
constexpr bool is_bq_col_major =
|
||||
std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>;
|
||||
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
|
||||
static_assert(is_bq_col_major, "Bq must be col major (row major not supported yet)");
|
||||
static_assert(NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}],
|
||||
"Bq block window has incorrect lengths for defined BqLayout!");
|
||||
|
||||
static_assert(is_a_col_major
|
||||
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
|
||||
: (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
|
||||
"A block window has incorrect lengths for defined ALayout!");
|
||||
static_assert(
|
||||
is_b_row_major
|
||||
? (KPerBlock / 2 == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
|
||||
: (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
KPerBlock / 2 == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
|
||||
"B block window has incorrect lengths for defined BLayout!");
|
||||
|
||||
// ------------------------------------------------------------------------------------
|
||||
// Definitions of all needed tiles
|
||||
// int b_block_stride = 0;
|
||||
// A/B tiles in LDS
|
||||
auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem);
|
||||
|
||||
// Tile distribution for load from lds
|
||||
constexpr auto a_lds_load_tile_distr =
|
||||
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
|
||||
constexpr auto b_lds_load_tile_distr =
|
||||
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
|
||||
|
||||
// A DRAM tile window for load
|
||||
// A LDS tile window for store
|
||||
// A LDS tile for block GEMM
|
||||
auto&& [a_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] =
|
||||
Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr);
|
||||
|
||||
// B DRAM tile window for load, (kN, kK/2)
|
||||
// B LDS tile window for store, (kN, kK)
|
||||
// B LDS tile for block GEMM
|
||||
auto&& [b_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] =
|
||||
Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr);
|
||||
|
||||
// B scale DRAM tile window for load
|
||||
// auto b_scale_copy_dram_window =
|
||||
// make_tile_window(bq_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
// bq_dram_block_window_tmp.get_window_lengths(),
|
||||
// bq_dram_block_window_tmp.get_window_origin(),
|
||||
// Policy::template GetBQDramLoadWindow<Problem>());
|
||||
auto bq_copy_dram_window = Base::GetBQDramLoadWindow(bq_dram_block_window_tmp);
|
||||
|
||||
auto bq_block_tile = decltype(load_tile(bq_copy_dram_window)){};
|
||||
|
||||
// Block GEMM
|
||||
auto block_gemm = BlockGemm();
|
||||
auto c_block_tile = block_gemm.MakeCBlockTile();
|
||||
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
|
||||
// using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
|
||||
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
|
||||
|
||||
using ABlockTile =
|
||||
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
|
||||
using BBlockTile =
|
||||
decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
|
||||
|
||||
ABlockTile a_block_tile;
|
||||
BBlockTile b_fp4_block_tile;
|
||||
|
||||
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
|
||||
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
|
||||
|
||||
constexpr ADramTileWindowStep a_dram_tile_window_step =
|
||||
is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
|
||||
constexpr BDramTileWindowStep b_dram_tile_window_step =
|
||||
is_b_row_major ? make_array(KPerBlock / 2, 0) : make_array(0, KPerBlock / 2);
|
||||
|
||||
constexpr index_t b_scale_dram_tile_window_step = KPerBlock / QuantGroupSize::kK;
|
||||
// -----------------------------------------------------------------------------------------
|
||||
// Gemm pipeline start
|
||||
|
||||
// prefetch
|
||||
// global read 0
|
||||
// auto a_scale_block_tile = decltype(load_tile(a_scale_copy_dram_window)){};
|
||||
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
// BDataType
|
||||
auto b_block_tile = make_static_distributed_tensor<BDqDataType>(
|
||||
Policy::template MakeBRegTileDistribution<Problem>());
|
||||
|
||||
bq_block_tile = load_tile(bq_copy_dram_window);
|
||||
move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step});
|
||||
|
||||
constexpr auto idx1_js = tile_distributed_index<0>{};
|
||||
constexpr auto b_block = decltype(b_fp4_block_tile)::get_distributed_spans();
|
||||
sweep_tile_span(b_block[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(b_block[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
constexpr auto i_j_idx_scale = make_tuple(idx0, idx1_js);
|
||||
auto b_scale_uint = type_convert<int32_t>(bq_block_tile(i_j_idx_scale)) - 127;
|
||||
auto b_scale = type_convert<float>(std::pow(2.0f, b_scale_uint));
|
||||
constexpr auto idx1_lo = tile_distributed_index<idx1.impl_.at(0) * 2>{};
|
||||
constexpr auto idx1_hi = tile_distributed_index<idx1.impl_.at(0) * 2 + 1>{};
|
||||
constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo);
|
||||
constexpr auto i_j_idx_hi = make_tuple(idx0, idx1_hi);
|
||||
|
||||
auto b_pack = type_convert<pk_fp4_t>(b_fp4_block_tile(i_j_idx));
|
||||
auto b_f4_lo = type_convert<pk_fp4_t>(b_pack.unpack(number<0>{}));
|
||||
auto b_f4_hi = type_convert<pk_fp4_t>(b_pack.unpack(number<1>{}));
|
||||
b_block_tile(i_j_idx_lo) =
|
||||
type_convert<bf16_t>(type_convert<float>(b_f4_lo) * b_scale);
|
||||
b_block_tile(i_j_idx_hi) =
|
||||
type_convert<bf16_t>(type_convert<float>(b_f4_hi) * b_scale);
|
||||
});
|
||||
});
|
||||
|
||||
// initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
block_sync_lds();
|
||||
|
||||
// LDS write 0
|
||||
if constexpr(is_a_col_major)
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_block_tile);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
|
||||
}
|
||||
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDqDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_block_tile);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
|
||||
}
|
||||
|
||||
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
bq_block_tile = load_tile(bq_copy_dram_window);
|
||||
move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step});
|
||||
|
||||
sweep_tile_span(b_block[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(b_block[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
constexpr auto i_j_idx_scale = make_tuple(idx0, idx1_js);
|
||||
|
||||
auto b_scale_uint = type_convert<int32_t>(bq_block_tile(i_j_idx_scale)) - 127;
|
||||
auto b_scale = type_convert<float>(std::pow(2.0f, b_scale_uint));
|
||||
constexpr auto idx1_lo = tile_distributed_index<idx1.impl_.at(0) * 2>{};
|
||||
constexpr auto idx1_hi = tile_distributed_index<idx1.impl_.at(0) * 2 + 1>{};
|
||||
constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo);
|
||||
constexpr auto i_j_idx_hi = make_tuple(idx0, idx1_hi);
|
||||
|
||||
auto b_pack = type_convert<pk_fp4_t>(b_fp4_block_tile(i_j_idx));
|
||||
auto b_f4_lo = type_convert<pk_fp4_t>(b_pack.unpack(number<0>{}));
|
||||
auto b_f4_hi = type_convert<pk_fp4_t>(b_pack.unpack(number<1>{}));
|
||||
b_block_tile(i_j_idx_lo) =
|
||||
type_convert<bf16_t>(type_convert<float>(b_f4_lo) * b_scale);
|
||||
b_block_tile(i_j_idx_hi) =
|
||||
type_convert<bf16_t>(type_convert<float>(b_f4_hi) * b_scale);
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// main body
|
||||
if constexpr(HasHotLoop)
|
||||
{
|
||||
index_t i = 0;
|
||||
do
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
if constexpr(is_a_col_major)
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_block_tile);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
|
||||
}
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDqDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_block_tile);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
|
||||
}
|
||||
|
||||
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(
|
||||
b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
bq_block_tile = load_tile(bq_copy_dram_window);
|
||||
move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step});
|
||||
|
||||
sweep_tile_span(b_block[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(b_block[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
constexpr auto i_j_idx_scale = make_tuple(idx0, idx1_js);
|
||||
|
||||
auto b_scale_uint =
|
||||
type_convert<int32_t>(bq_block_tile(i_j_idx_scale)) - 127;
|
||||
auto b_scale = type_convert<float>(std::pow(2.0f, b_scale_uint));
|
||||
constexpr auto idx1_lo = tile_distributed_index<idx1.impl_.at(0) * 2>{};
|
||||
constexpr auto idx1_hi =
|
||||
tile_distributed_index<idx1.impl_.at(0) * 2 + 1>{};
|
||||
constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo);
|
||||
constexpr auto i_j_idx_hi = make_tuple(idx0, idx1_hi);
|
||||
|
||||
auto b_pack = type_convert<pk_fp4_t>(b_fp4_block_tile(i_j_idx));
|
||||
auto b_f4_lo = type_convert<pk_fp4_t>(b_pack.unpack(number<0>{}));
|
||||
auto b_f4_hi = type_convert<pk_fp4_t>(b_pack.unpack(number<1>{}));
|
||||
b_block_tile(i_j_idx_lo) =
|
||||
type_convert<bf16_t>(type_convert<float>(b_f4_lo) * b_scale);
|
||||
b_block_tile(i_j_idx_hi) =
|
||||
type_convert<bf16_t>(type_convert<float>(b_f4_hi) * b_scale);
|
||||
});
|
||||
});
|
||||
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_sync_lds();
|
||||
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
i += 1;
|
||||
// b_block_stride +=1;
|
||||
} while(i < (num_loop - 1));
|
||||
}
|
||||
// tile_elementwise_inout([](auto& c) { c = 0; }, acc_block_tile);
|
||||
// tail
|
||||
if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd))
|
||||
{
|
||||
// Leak last MFMA block to epilogue region, cover the potential lds-shuffle
|
||||
// latency
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_sync_lds();
|
||||
|
||||
if constexpr(is_a_col_major)
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_block_tile);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
|
||||
}
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDqDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_block_tile);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
|
||||
}
|
||||
|
||||
block_sync_lds();
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_sync_lds();
|
||||
}
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
return c_block_tile;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief This function runs the pipeline using compile-time known hot loop and tail number.
|
||||
* @param num_loop The number of loop iterations. This is determined at runtime due to e.g.
|
||||
* SplitK.
|
||||
* @note This is used by the kernel variants that are able to determine
|
||||
* hot loop and tail number on the host side, e.g. non-persistent gemm kernel.
|
||||
*/
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename BQDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem,
|
||||
index_t n = 0) const
|
||||
{
|
||||
ck_tile::ignore = n;
|
||||
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](const BDqDataType& b) { return b; },
|
||||
bq_dram_block_window_tmp,
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user