mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-07 00:04:37 +00:00
Merge branch 'develop' into ck_tile/fa_train
This commit is contained in:
42
.azuredevops/rocm-ci.yml
Normal file
42
.azuredevops/rocm-ci.yml
Normal file
@@ -0,0 +1,42 @@
|
||||
# Azure Pipelines entry point for composable_kernel. The job definition itself is not in this
# file: it is pulled from a template in the external ROCm/ROCm repository declared below.
resources:
  repositories:
  # Alias used by the `@pipelines_repo` template references further down.
  - repository: pipelines_repo
    type: github
    endpoint: ROCm
    name: ROCm/ROCm

variables:
- group: common
- template: /.azuredevops/variables-global.yml@pipelines_repo

# CI runs for pushes to develop; changes that only touch the excluded paths do not trigger.
trigger:
  batch: true
  branches:
    include:
    - develop
  paths:
    exclude:
    - .github
    - docs
    - '.*.y*ml'
    - '*.md'
    - Jenkinsfile
    - LICENSE

# Pull-request validation against develop, with the same path exclusions as `trigger`.
# Keep the two exclude lists in sync when editing either one.
pr:
  autoCancel: true
  branches:
    include:
    - develop
  paths:
    exclude:
    - .github
    - docs
    - '.*.y*ml'
    - '*.md'
    - Jenkinsfile
    - LICENSE
  drafts: false

jobs:
# NOTE(review): CI_COMPONENT_PATH is not defined in this file; presumably it comes from
# variables-global.yml in pipelines_repo - confirm.
- template: ${{ variables.CI_COMPONENT_PATH }}/composable_kernel.yml@pipelines_repo
|
||||
@@ -23,7 +23,7 @@ endif()
|
||||
|
||||
set(version 1.1.0)
|
||||
# Check support for CUDA/HIP in CMake
|
||||
project(composable_kernel VERSION ${version} LANGUAGES CXX)
|
||||
project(composable_kernel VERSION ${version} LANGUAGES CXX HIP)
|
||||
include(CTest)
|
||||
|
||||
find_package(Python3 3.6 COMPONENTS Interpreter REQUIRED)
|
||||
@@ -112,7 +112,7 @@ message("checking which targets are supported")
|
||||
#Setting GPU_TARGETS on command line will override this list
|
||||
if(NOT PROFILER_ONLY)
|
||||
rocm_check_target_ids(DEFAULT_GPU_TARGETS
|
||||
TARGETS "gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102")
|
||||
TARGETS "gfx900;gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102")
|
||||
else()
|
||||
add_definitions(-DPROFILER_ONLY)
|
||||
set(GPU_TARGETS "" CACHE STRING "" FORCE)
|
||||
@@ -135,12 +135,10 @@ endif()
|
||||
|
||||
message("Supported GPU_TARGETS= ${DEFAULT_GPU_TARGETS}")
|
||||
|
||||
set(AMDGPU_TARGETS "${DEFAULT_GPU_TARGETS}" CACHE STRING " " FORCE)
|
||||
|
||||
if(GPU_TARGETS)
|
||||
message("Building CK for the following targets: ${GPU_TARGETS}")
|
||||
else()
|
||||
message("Building CK for the following targets: ${AMDGPU_TARGETS}")
|
||||
message("Building CK for the default targets: ${DEFAULT_GPU_TARGETS}")
|
||||
endif()
|
||||
|
||||
if (GPU_TARGETS)
|
||||
@@ -225,7 +223,13 @@ link_libraries(Threads::Threads)
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
set(CMAKE_CXX_EXTENSIONS OFF)
|
||||
message("CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}")
|
||||
message("CMAKE_CXX_COMPILER: ${CMAKE_CXX_COMPILER}")
|
||||
|
||||
## HIP
|
||||
set(CMAKE_HIP_PLATFORM amd)
|
||||
set(CMAKE_HIP_COMPILER ${CMAKE_CXX_COMPILER})
|
||||
set(CMAKE_HIP_EXTENSIONS ON)
|
||||
message("CMAKE_HIP_COMPILER: ${CMAKE_HIP_COMPILER}")
|
||||
|
||||
## OpenMP
|
||||
if(CMAKE_CXX_COMPILER_ID MATCHES "Clang")
|
||||
|
||||
@@ -4,4 +4,22 @@ if(GPU_TARGETS MATCHES "gfx9")
|
||||
|
||||
add_executable(client_grouped_conv1d_fwd grouped_conv1d_fwd.cpp)
|
||||
target_link_libraries(client_grouped_conv1d_fwd PRIVATE composable_kernel::device_conv_operations)
|
||||
endif()
|
||||
|
||||
# FP8 / BF8 3-D grouped-convolution client examples.
# Each group is built when DTYPES is not defined at all, or when DTYPES matches every
# data type the example needs.
if(NOT DEFINED DTYPES OR (DTYPES MATCHES "fp8"))
    add_executable(client_grouped_conv3d_fwd_fp8 grouped_conv3d_fwd_fp8.cpp)
    target_link_libraries(client_grouped_conv3d_fwd_fp8
                          PRIVATE composable_kernel::device_conv_operations)
endif()

if(NOT DEFINED DTYPES OR (DTYPES MATCHES "bf8"))
    add_executable(client_grouped_conv3d_fwd_bf8 grouped_conv3d_fwd_bf8.cpp)
    target_link_libraries(client_grouped_conv3d_fwd_bf8
                          PRIVATE composable_kernel::device_conv_operations)
endif()

# Mixed-precision variants need both fp8 and bf8 to be enabled.
if(NOT DEFINED DTYPES OR (DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8"))
    foreach(mixed_pair IN ITEMS fp8_bf8 bf8_fp8)
        add_executable(client_grouped_conv3d_fwd_${mixed_pair}
                       grouped_conv3d_fwd_${mixed_pair}.cpp)
        target_link_libraries(client_grouped_conv3d_fwd_${mixed_pair}
                              PRIVATE composable_kernel::device_conv_operations)
    endforeach()
endif()
|
||||
endif()
|
||||
|
||||
304
client_example/07_grouped_convnd_fwd/common.hpp
Normal file
304
client_example/07_grouped_convnd_fwd/common.hpp
Normal file
@@ -0,0 +1,304 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <iterator>
|
||||
#include <numeric>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
struct SimpleDeviceMem
|
||||
{
|
||||
SimpleDeviceMem() = delete;
|
||||
|
||||
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
|
||||
{
|
||||
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
|
||||
}
|
||||
|
||||
void* GetDeviceBuffer() { return p_mem_; }
|
||||
|
||||
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
|
||||
|
||||
void* p_mem_;
|
||||
};
|
||||
|
||||
template <ck::index_t NumDimSpatial, ck::index_t NumNonSpatialDim = 3>
|
||||
std::size_t
|
||||
GetFlops(const std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>& output_lengths,
|
||||
const std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>& weights_lengths)
|
||||
{
|
||||
// 2 * G * N * K * C * <output spatial lengths product> * <filter spatial lengths product>
|
||||
ck::index_t G = weights_lengths[0];
|
||||
ck::index_t N = output_lengths[1];
|
||||
ck::index_t K = weights_lengths[1];
|
||||
ck::index_t C = weights_lengths[2];
|
||||
|
||||
return static_cast<std::size_t>(2) * G * N * K * C *
|
||||
std::accumulate(std::next(std::begin(output_lengths), NumNonSpatialDim),
|
||||
std::end(output_lengths),
|
||||
static_cast<std::size_t>(1),
|
||||
std::multiplies<>()) *
|
||||
std::accumulate(std::next(std::begin(weights_lengths), NumNonSpatialDim),
|
||||
std::end(weights_lengths),
|
||||
static_cast<std::size_t>(1),
|
||||
std::multiplies<>());
|
||||
}
|
||||
|
||||
template <typename InDataType, ck::index_t NumDimSpatial, ck::index_t NumNonSpatialDim = 3>
|
||||
std::size_t
|
||||
GetInputByte(const std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>& input_lengths)
|
||||
{
|
||||
// sizeof(InDataType) * (G * N * C * <input spatial lengths product>) +
|
||||
return sizeof(InDataType) * std::accumulate(std::begin(input_lengths),
|
||||
std::end(input_lengths),
|
||||
static_cast<std::size_t>(1),
|
||||
std::multiplies<>());
|
||||
}
|
||||
|
||||
template <typename WeiDataType, ck::index_t NumDimSpatial, ck::index_t NumNonSpatialDim = 3>
|
||||
std::size_t
|
||||
GetWeightByte(const std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>& weights_lengths)
|
||||
{
|
||||
// sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) +
|
||||
return sizeof(WeiDataType) * std::accumulate(std::begin(weights_lengths),
|
||||
std::end(weights_lengths),
|
||||
static_cast<std::size_t>(1),
|
||||
std::multiplies<>());
|
||||
}
|
||||
|
||||
template <typename OutDataType, ck::index_t NumDimSpatial, ck::index_t NumNonSpatialDim = 3>
|
||||
std::size_t
|
||||
GetOutputByte(const std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>& output_lengths)
|
||||
{
|
||||
// sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
|
||||
return sizeof(OutDataType) * std::accumulate(std::begin(output_lengths),
|
||||
std::end(output_lengths),
|
||||
static_cast<std::size_t>(1),
|
||||
std::multiplies<std::size_t>());
|
||||
}
|
||||
|
||||
template <ck::index_t NumDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
ck::index_t NumNonSpatialDim = 3,
|
||||
typename AComputeType = InDataType,
|
||||
typename BComputeType = AComputeType>
|
||||
bool run_grouped_conv_fwd(std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> in_lengths,
|
||||
std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> wei_lengths,
|
||||
std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> out_lengths)
|
||||
{
|
||||
std::size_t in_mem_size = GetInputByte<InDataType, NumDimSpatial>(in_lengths);
|
||||
std::size_t wei_mem_size = GetWeightByte<WeiDataType, NumDimSpatial>(wei_lengths);
|
||||
std::size_t out_mem_size = GetOutputByte<OutDataType, NumDimSpatial>(out_lengths);
|
||||
|
||||
SimpleDeviceMem in(in_mem_size);
|
||||
SimpleDeviceMem wei(wei_mem_size);
|
||||
SimpleDeviceMem out(out_mem_size);
|
||||
|
||||
std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> in_strides;
|
||||
std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> wei_strides;
|
||||
std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> out_strides;
|
||||
in_strides.fill(0);
|
||||
wei_strides.fill(0);
|
||||
out_strides.fill(0);
|
||||
in_strides.back() = 1;
|
||||
wei_strides.back() = 1;
|
||||
out_strides.back() = 1;
|
||||
|
||||
std::partial_sum(rbegin(in_lengths),
|
||||
std::prev(rend(in_lengths)),
|
||||
std::next(rbegin(in_strides)),
|
||||
std::multiplies<>{});
|
||||
std::partial_sum(rbegin(wei_lengths),
|
||||
std::prev(rend(wei_lengths)),
|
||||
std::next(rbegin(wei_strides)),
|
||||
std::multiplies<>{});
|
||||
std::partial_sum(rbegin(out_lengths),
|
||||
std::prev(rend(out_lengths)),
|
||||
std::next(rbegin(out_strides)),
|
||||
std::multiplies<>{});
|
||||
|
||||
// transpose NDHWGC/KZYXGC/NDHWGK to GNDHWC/GKZYXC/GNDHWK to GNCDHW/GKCZYX/GNKDHW
|
||||
std::rotate(std::next(rbegin(in_lengths)), std::next(rbegin(in_lengths), 2), rend(in_lengths));
|
||||
std::rotate(rbegin(in_lengths),
|
||||
std::next(rbegin(in_lengths)),
|
||||
std::next(rbegin(in_lengths), NumDimSpatial + 1));
|
||||
|
||||
std::rotate(std::next(rbegin(in_strides)), std::next(rbegin(in_strides), 2), rend(in_strides));
|
||||
std::rotate(rbegin(in_strides),
|
||||
std::next(rbegin(in_strides)),
|
||||
std::next(rbegin(in_strides), NumDimSpatial + 1));
|
||||
|
||||
std::rotate(rbegin(wei_lengths),
|
||||
std::next(rbegin(wei_lengths)),
|
||||
std::next(rbegin(wei_lengths), NumDimSpatial + 1));
|
||||
|
||||
std::rotate(rbegin(wei_strides),
|
||||
std::next(rbegin(wei_strides)),
|
||||
std::next(rbegin(wei_strides), NumDimSpatial + 1));
|
||||
|
||||
std::rotate(
|
||||
std::next(rbegin(out_lengths)), std::next(rbegin(out_lengths), 2), rend(out_lengths));
|
||||
std::rotate(rbegin(out_lengths),
|
||||
std::next(rbegin(out_lengths)),
|
||||
std::next(rbegin(out_lengths), NumDimSpatial + 1));
|
||||
|
||||
std::rotate(
|
||||
std::next(rbegin(out_strides)), std::next(rbegin(out_strides), 2), rend(out_strides));
|
||||
std::rotate(rbegin(out_strides),
|
||||
std::next(rbegin(out_strides)),
|
||||
std::next(rbegin(out_strides), NumDimSpatial + 1));
|
||||
|
||||
std::array<ck::index_t, NumDimSpatial> conv_filter_strides;
|
||||
std::array<ck::index_t, NumDimSpatial> conv_filter_dilations;
|
||||
std::array<ck::index_t, NumDimSpatial> input_left_pads;
|
||||
std::array<ck::index_t, NumDimSpatial> input_right_pads;
|
||||
conv_filter_strides.fill(1);
|
||||
conv_filter_dilations.fill(1);
|
||||
input_left_pads.fill(1);
|
||||
input_right_pads.fill(1);
|
||||
|
||||
std::size_t flop = GetFlops<NumDimSpatial>(out_lengths, wei_lengths);
|
||||
std::size_t num_bytes = in_mem_size + wei_mem_size + out_mem_size;
|
||||
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
ck::Tuple<>,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
ck::Tuple<>,
|
||||
OutDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AComputeType,
|
||||
BComputeType>;
|
||||
// get device op instances
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
|
||||
|
||||
std::string best_op_name;
|
||||
int best_op_id = -1;
|
||||
float best_avg_time = std::numeric_limits<float>::max();
|
||||
float best_gb_per_sec = 0;
|
||||
float best_tflops = 0;
|
||||
|
||||
// profile device operation instances
|
||||
std::cout << "Run all instances and do timing" << std::endl;
|
||||
|
||||
for(int i = 0; i < op_ptrs.size(); ++i)
|
||||
{
|
||||
auto& op_ptr = op_ptrs[i];
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(
|
||||
in.GetDeviceBuffer(),
|
||||
wei.GetDeviceBuffer(),
|
||||
std::array<const void*, 0>{},
|
||||
out.GetDeviceBuffer(),
|
||||
in_lengths,
|
||||
in_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
std::array<std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>, 0>{{}},
|
||||
std::array<std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>, 0>{{}},
|
||||
out_lengths,
|
||||
out_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
PassThrough{});
|
||||
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
std::string op_name = op_ptr->GetTypeString();
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
|
||||
float gb_per_sec = num_bytes / 1.E6 / avg_time;
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, "
|
||||
<< gb_per_sec << " GB/s, " << op_name << std::endl;
|
||||
|
||||
if(tflops > best_tflops)
|
||||
{
|
||||
best_op_id = i;
|
||||
best_op_name = op_name;
|
||||
best_avg_time = avg_time;
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
best_tflops = tflops;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << op_name << " does not support this problem" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
if(best_op_id < 0)
|
||||
{
|
||||
std::cerr << "no suitable instance" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops
|
||||
<< " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
|
||||
|
||||
// run the best intance
|
||||
{
|
||||
auto& op_ptr = op_ptrs[best_op_id];
|
||||
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
|
||||
<< std::endl;
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(
|
||||
in.GetDeviceBuffer(),
|
||||
wei.GetDeviceBuffer(),
|
||||
std::array<const void*, 0>{},
|
||||
out.GetDeviceBuffer(),
|
||||
in_lengths,
|
||||
in_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
std::array<std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>, 0>{{}},
|
||||
std::array<std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>, 0>{{}},
|
||||
out_lengths,
|
||||
out_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
PassThrough{});
|
||||
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
|
||||
}
|
||||
|
||||
std::cout << "Done" << std::endl;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
@@ -1,17 +1,10 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <iterator>
|
||||
#include <numeric>
|
||||
#include <vector>
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
using InDataType = ck::half_t;
|
||||
using WeiDataType = ck::half_t;
|
||||
@@ -31,199 +24,16 @@ static constexpr ck::index_t X = 3;
|
||||
static constexpr ck::index_t Wi = 28;
|
||||
static constexpr ck::index_t Wo = 28;
|
||||
|
||||
struct SimpleDeviceMem
|
||||
{
|
||||
SimpleDeviceMem() = delete;
|
||||
|
||||
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
|
||||
{
|
||||
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
|
||||
}
|
||||
|
||||
void* GetDeviceBuffer() { return p_mem_; }
|
||||
|
||||
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
|
||||
|
||||
void* p_mem_;
|
||||
};
|
||||
|
||||
int main()
|
||||
{
|
||||
std::array<ck::index_t, NumDimSpatial + 3> in_lengths{G, N, Wi, C};
|
||||
std::array<ck::index_t, NumDimSpatial + 3> in_strides{0, 0, 0, 1};
|
||||
|
||||
std::array<ck::index_t, NumDimSpatial + 3> wei_lengths{G, K, X, C};
|
||||
std::array<ck::index_t, NumDimSpatial + 3> wei_strides{0, 0, 0, 1};
|
||||
|
||||
std::array<ck::index_t, NumDimSpatial + 3> out_lengths{G, N, Wo, K};
|
||||
std::array<ck::index_t, NumDimSpatial + 3> out_strides{0, 0, 0, 1};
|
||||
|
||||
std::partial_sum(rbegin(in_lengths),
|
||||
std::prev(rend(in_lengths)),
|
||||
std::next(rbegin(in_strides)),
|
||||
std::multiplies<>{});
|
||||
std::partial_sum(rbegin(wei_lengths),
|
||||
std::prev(rend(wei_lengths)),
|
||||
std::next(rbegin(wei_strides)),
|
||||
std::multiplies<>{});
|
||||
std::partial_sum(rbegin(out_lengths),
|
||||
std::prev(rend(out_lengths)),
|
||||
std::next(rbegin(out_strides)),
|
||||
std::multiplies<>{});
|
||||
|
||||
// transpose GNWC/GKXC/GNWK to GNCW/GKCX/GNCW
|
||||
std::rotate(rbegin(in_lengths),
|
||||
std::next(rbegin(in_lengths)),
|
||||
std::next(rbegin(in_lengths), NumDimSpatial + 1));
|
||||
std::rotate(rbegin(in_strides),
|
||||
std::next(rbegin(in_strides)),
|
||||
std::next(rbegin(in_strides), NumDimSpatial + 1));
|
||||
std::rotate(rbegin(wei_lengths),
|
||||
std::next(rbegin(wei_lengths)),
|
||||
std::next(rbegin(wei_lengths), NumDimSpatial + 1));
|
||||
std::rotate(rbegin(wei_strides),
|
||||
std::next(rbegin(wei_strides)),
|
||||
std::next(rbegin(wei_strides), NumDimSpatial + 1));
|
||||
std::rotate(rbegin(out_lengths),
|
||||
std::next(rbegin(out_lengths)),
|
||||
std::next(rbegin(out_lengths), NumDimSpatial + 1));
|
||||
std::rotate(rbegin(out_strides),
|
||||
std::next(rbegin(out_strides)),
|
||||
std::next(rbegin(out_strides), NumDimSpatial + 1));
|
||||
|
||||
std::array<ck::index_t, NumDimSpatial> filter_strides{1};
|
||||
std::array<ck::index_t, NumDimSpatial> filter_dilations{1};
|
||||
std::array<ck::index_t, NumDimSpatial> input_left_pads{1};
|
||||
std::array<ck::index_t, NumDimSpatial> input_right_pads{1};
|
||||
|
||||
SimpleDeviceMem in(sizeof(InDataType) * G * N * Wi * C);
|
||||
SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * X * C);
|
||||
SimpleDeviceMem out(sizeof(OutDataType) * G * N * Wo * K);
|
||||
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
ck::Tuple<>,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
ck::Tuple<>,
|
||||
OutDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>;
|
||||
|
||||
// get device op instances
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
|
||||
|
||||
std::string best_op_name;
|
||||
int best_op_id = -1;
|
||||
float best_avg_time = std::numeric_limits<float>::max();
|
||||
float best_gb_per_sec = 0;
|
||||
float best_tflops = 0;
|
||||
|
||||
// profile device operation instances
|
||||
std::cout << "Run all instances and do timing" << std::endl;
|
||||
|
||||
for(int i = 0; i < op_ptrs.size(); ++i)
|
||||
{
|
||||
auto& op_ptr = op_ptrs[i];
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(),
|
||||
wei.GetDeviceBuffer(),
|
||||
{},
|
||||
out.GetDeviceBuffer(),
|
||||
in_lengths,
|
||||
in_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
{},
|
||||
{},
|
||||
out_lengths,
|
||||
out_strides,
|
||||
filter_strides,
|
||||
filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
PassThrough{});
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
std::string op_name = op_ptr->GetTypeString();
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
|
||||
|
||||
std::size_t flop = std::size_t(2) * G * N * K * C * Wo * X;
|
||||
std::size_t num_bytes = sizeof(InDataType) * G * N * Wi * C +
|
||||
sizeof(WeiDataType) * G * K * X * C +
|
||||
sizeof(OutDataType) * G * N * Wo * K;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
|
||||
float gb_per_sec = num_bytes / 1.E6 / avg_time;
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, "
|
||||
<< gb_per_sec << " GB/s, " << op_name << std::endl;
|
||||
|
||||
if(tflops > best_tflops)
|
||||
{
|
||||
best_op_id = i;
|
||||
best_op_name = op_name;
|
||||
best_avg_time = avg_time;
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
best_tflops = tflops;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << op_name << " does not support this problem" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
if(best_op_id < 0)
|
||||
{
|
||||
std::cerr << "no suitable instance" << std::endl;
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
|
||||
std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops
|
||||
<< " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
|
||||
|
||||
// run the best intance
|
||||
{
|
||||
auto& op_ptr = op_ptrs[best_op_id];
|
||||
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
|
||||
<< std::endl;
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(),
|
||||
wei.GetDeviceBuffer(),
|
||||
{},
|
||||
out.GetDeviceBuffer(),
|
||||
in_lengths,
|
||||
in_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
{},
|
||||
{},
|
||||
out_lengths,
|
||||
out_strides,
|
||||
filter_strides,
|
||||
filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
PassThrough{});
|
||||
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
|
||||
}
|
||||
|
||||
std::cout << "Done" << std::endl;
|
||||
}
|
||||
return run_grouped_conv_fwd<NumDimSpatial,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
3>({N, Wi, G, C}, {G, K, X, C}, {N, Wo, G, K})
|
||||
? EXIT_SUCCESS
|
||||
: EXIT_FAILURE;
|
||||
}
|
||||
|
||||
@@ -1,17 +1,10 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <iterator>
|
||||
#include <numeric>
|
||||
#include <vector>
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
using InDataType = ck::half_t;
|
||||
using WeiDataType = ck::half_t;
|
||||
@@ -34,167 +27,16 @@ static constexpr ck::index_t Wi = 28; // input W
|
||||
static constexpr ck::index_t Ho = 28; // output H
|
||||
static constexpr ck::index_t Wo = 28; // output W
|
||||
|
||||
struct SimpleDeviceMem
|
||||
{
|
||||
SimpleDeviceMem() = delete;
|
||||
|
||||
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
|
||||
{
|
||||
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
|
||||
}
|
||||
|
||||
void* GetDeviceBuffer() { return p_mem_; }
|
||||
|
||||
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
|
||||
|
||||
void* p_mem_;
|
||||
};
|
||||
|
||||
int main()
|
||||
{
|
||||
// We have NHWGC/GKYXC/NHWGK (x, weight, y) in memory space
|
||||
// However, CK's API only accept length and stride with order of GNCHW/GKCYX/GNCHW
|
||||
// Hence, we need to adjust the order of stride
|
||||
std::array<ck::index_t, 5> in_lengths{G, N, C, Hi, Wi};
|
||||
std::array<ck::index_t, 5> in_strides{C, Hi * Wi * G * C, 1, Wi * G * C, G * C};
|
||||
std::array<ck::index_t, 5> wei_lengths{G, K, C, Y, X};
|
||||
std::array<ck::index_t, 5> wei_strides{K * Y * X * C, Y * X * C, 1, X * C, C};
|
||||
std::array<ck::index_t, 5> out_lengths{G, N, K, Ho, Wo};
|
||||
std::array<ck::index_t, 5> out_strides{C, Ho * Wo * G * C, 1, Wo * G * C, G * C};
|
||||
|
||||
std::array<ck::index_t, NumDimSpatial> filter_strides{1, 1};
|
||||
std::array<ck::index_t, NumDimSpatial> filter_dilations{1, 1};
|
||||
std::array<ck::index_t, NumDimSpatial> input_left_pads{1, 1};
|
||||
std::array<ck::index_t, NumDimSpatial> input_right_pads{1, 1};
|
||||
|
||||
SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * G * C);
|
||||
SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Y * X * C);
|
||||
SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K);
|
||||
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
ck::Tuple<>,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
ck::Tuple<>,
|
||||
OutDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>;
|
||||
|
||||
// get device op instances
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
|
||||
|
||||
std::string best_op_name;
|
||||
int best_op_id = -1;
|
||||
float best_avg_time = std::numeric_limits<float>::max();
|
||||
float best_gb_per_sec = 0;
|
||||
float best_tflops = 0;
|
||||
|
||||
// profile device operation instances
|
||||
std::cout << "Run all instances and do timing" << std::endl;
|
||||
|
||||
for(int i = 0; i < op_ptrs.size(); ++i)
|
||||
{
|
||||
auto& op_ptr = op_ptrs[i];
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(),
|
||||
wei.GetDeviceBuffer(),
|
||||
{},
|
||||
out.GetDeviceBuffer(),
|
||||
in_lengths,
|
||||
in_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
{},
|
||||
{},
|
||||
out_lengths,
|
||||
out_strides,
|
||||
filter_strides,
|
||||
filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
PassThrough{});
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
std::string op_name = op_ptr->GetTypeString();
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
|
||||
|
||||
std::size_t flop = std::size_t(2) * G * N * K * C * Ho * Wo * Y * X;
|
||||
std::size_t num_bytes = sizeof(InDataType) * N * Hi * Wi * G * C +
|
||||
sizeof(WeiDataType) * G * K * Y * X * C +
|
||||
sizeof(OutDataType) * N * Ho * Wo * G * K;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
|
||||
float gb_per_sec = num_bytes / 1.E6 / avg_time;
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, "
|
||||
<< gb_per_sec << " GB/s, " << op_name << std::endl;
|
||||
|
||||
if(tflops > best_tflops)
|
||||
{
|
||||
best_op_id = i;
|
||||
best_op_name = op_name;
|
||||
best_avg_time = avg_time;
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
best_tflops = tflops;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << op_name << " does not support this problem" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
if(best_op_id < 0)
|
||||
{
|
||||
std::cerr << "no suitable instance" << std::endl;
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
|
||||
std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops
|
||||
<< " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
|
||||
|
||||
// run the best intance
|
||||
{
|
||||
auto& op_ptr = op_ptrs[best_op_id];
|
||||
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
|
||||
<< std::endl;
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(),
|
||||
wei.GetDeviceBuffer(),
|
||||
{},
|
||||
out.GetDeviceBuffer(),
|
||||
in_lengths,
|
||||
in_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
{},
|
||||
{},
|
||||
out_lengths,
|
||||
out_strides,
|
||||
filter_strides,
|
||||
filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
PassThrough{});
|
||||
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
|
||||
}
|
||||
|
||||
std::cout << "Done" << std::endl;
|
||||
}
|
||||
return run_grouped_conv_fwd<NumDimSpatial,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
3>({N, Hi, Wi, G, C}, {G, K, Y, X, C}, {N, Ho, Wo, G, K})
|
||||
? EXIT_SUCCESS
|
||||
: EXIT_FAILURE;
|
||||
}
|
||||
|
||||
@@ -7,22 +7,6 @@ endif()
|
||||
if((DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES)
|
||||
add_executable(client_conv3d_fwd_fp16_comp_fp8 conv3d_fwd_fp16_comp_fp8.cpp)
|
||||
target_link_libraries(client_conv3d_fwd_fp16_comp_fp8 PRIVATE composable_kernel::device_conv_operations)
|
||||
|
||||
add_executable(client_conv3d_fwd_fp8 conv3d_fwd_fp8.cpp)
|
||||
target_link_libraries(client_conv3d_fwd_fp8 PRIVATE composable_kernel::device_conv_operations)
|
||||
endif()
|
||||
|
||||
if((DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES)
|
||||
add_executable(client_conv3d_fwd_bf8 conv3d_fwd_bf8.cpp)
|
||||
target_link_libraries(client_conv3d_fwd_bf8 PRIVATE composable_kernel::device_conv_operations)
|
||||
endif()
|
||||
|
||||
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES)
|
||||
add_executable(client_conv3d_fwd_fp8_bf8 conv3d_fwd_fp8_bf8.cpp)
|
||||
target_link_libraries(client_conv3d_fwd_fp8_bf8 PRIVATE composable_kernel::device_conv_operations)
|
||||
|
||||
add_executable(client_conv3d_fwd_bf8_fp8 conv3d_fwd_bf8_fp8.cpp)
|
||||
target_link_libraries(client_conv3d_fwd_bf8_fp8 PRIVATE composable_kernel::device_conv_operations)
|
||||
endif()
|
||||
|
||||
if((DTYPES MATCHES "fp32") OR NOT DEFINED DTYPES)
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
rocm-docs-core==1.1.1
|
||||
rocm-docs-core==1.1.3
|
||||
sphinxcontrib-bibtex==2.6.2
|
||||
|
||||
@@ -103,7 +103,7 @@ requests==2.31.0
|
||||
# via
|
||||
# pygithub
|
||||
# sphinx
|
||||
rocm-docs-core==1.1.1
|
||||
rocm-docs-core==1.1.3
|
||||
# via -r requirements.in
|
||||
six==1.16.0
|
||||
# via
|
||||
|
||||
@@ -44,6 +44,13 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
|
||||
endif()
|
||||
endforeach()
|
||||
endif()
|
||||
|
||||
if(INSTANCES_ONLY)
|
||||
set(EX_TARGETS ${DEFAULT_GPU_TARGETS})
|
||||
else()
|
||||
set(EX_TARGETS ${GPU_TARGETS})
|
||||
endif()
|
||||
|
||||
#Do not build any DL examples if DL_KERNELS not set
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
|
||||
@@ -53,23 +60,30 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
|
||||
endforeach()
|
||||
#Do not build any XDL examples if gfx9 targets are not on the list
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl")
|
||||
if(NOT EX_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl")
|
||||
message("removing xdl example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#Do not build any WMMA examples if gfx11 targets are not on the list
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma")
|
||||
if(NOT EX_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma")
|
||||
message("removing wmma example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#only continue if there are some source files left on the list
|
||||
if(FILE_NAME)
|
||||
# Narrow EX_TARGETS according to the kernel family named in the source file(s).
# EX_TARGETS is applied below as the HIP_ARCHITECTURES property of the example
# target. The removal lists are kept identical to the ones used by
# add_example_executable_no_testing so both helpers build an example for the
# same set of architectures.
if(FILE_NAME MATCHES "_xdl")
    # XDL examples: drop gfx900/gfx906 and every gfx10/gfx11 target.
    list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103)
elseif(FILE_NAME MATCHES "_wmma")
    # WMMA examples: drop every gfx9 target and gfx1030.
    list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
endif()
|
||||
set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP)
|
||||
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
|
||||
target_link_libraries(${EXAMPLE_NAME} PRIVATE utility)
|
||||
add_test(NAME ${EXAMPLE_NAME} COMMAND $<TARGET_FILE:${EXAMPLE_NAME}> ${ARGN})
|
||||
set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS} )
|
||||
add_dependencies(examples ${EXAMPLE_NAME})
|
||||
add_dependencies(check ${EXAMPLE_NAME})
|
||||
rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples)
|
||||
@@ -118,6 +132,12 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
|
||||
endif()
|
||||
endforeach()
|
||||
endif()
|
||||
|
||||
if(INSTANCES_ONLY)
|
||||
set(EX_TARGETS ${DEFAULT_GPU_TARGETS})
|
||||
else()
|
||||
set(EX_TARGETS ${GPU_TARGETS})
|
||||
endif()
|
||||
#Do not build any DL examples if DL_KERNELS not set
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
|
||||
@@ -127,23 +147,30 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
|
||||
endforeach()
|
||||
#Do not build any XDL examples if gfx9 targets are not on the list
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl")
|
||||
if(NOT EX_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl")
|
||||
message("removing xdl example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#Do not build any WMMA examples if gfx11 targets are not on the list
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma")
|
||||
if(NOT EX_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma")
|
||||
message("removing wmma example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#only continue if there are some source files left on the list
|
||||
if(FILE_NAME)
|
||||
if(FILE_NAME MATCHES "_xdl")
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103)
|
||||
elseif(FILE_NAME MATCHES "_wmma")
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
|
||||
endif()
|
||||
set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP)
|
||||
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
|
||||
target_link_libraries(${EXAMPLE_NAME} PRIVATE utility)
|
||||
add_dependencies(examples ${EXAMPLE_NAME})
|
||||
set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS} )
|
||||
rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples)
|
||||
set(result 0)
|
||||
endif()
|
||||
|
||||
@@ -34,6 +34,7 @@ args:
|
||||
if not equal to h, then this is GQA/MQA case
|
||||
-s seqlen_q. if group-mode, means the average value of seqlen_q (default:3328)
|
||||
total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary
|
||||
also with "-s=s0,s1,s2..." comma separated int to set per batch seqlen(group-mode)
|
||||
-s_k seqlen_k, -1 means equal to s (default:-1)
|
||||
-d head dim for q, k (default:128)
|
||||
-d_v head dim for v, -1 means equal to d (default:-1)
|
||||
|
||||
@@ -44,11 +44,18 @@ auto create_args(int argc, char* argv[])
|
||||
"-1",
|
||||
"num of head, for k/v, -1 means equal to h\n"
|
||||
"if not equal to h, then this is GQA/MQA case")
|
||||
.insert("s",
|
||||
"3328",
|
||||
"seqlen_q. if group-mode, means the average value of seqlen_q\n"
|
||||
"total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary")
|
||||
.insert(
|
||||
"s",
|
||||
"3328",
|
||||
"seqlen_q. if group-mode, means the average value of seqlen_q\n"
|
||||
"total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary\n"
|
||||
"also with \"-s=s0,s1,s2...\" comma seperated int to set per batch seqlen(group-mode)")
|
||||
.insert("s_k", "-1", "seqlen_k, -1 means equal to s")
|
||||
.insert("s_kpad",
|
||||
"-1",
|
||||
"seqlen_k stride between 2 tokens, currently used in group-mode only\n"
|
||||
"for kv-cache case, each batch [1,s,h,d]/[1,h,s,d] can have a stride\n"
|
||||
"along seqlen, instead of packed. same as xformer kv_padding")
|
||||
.insert("d", "128", "head dim for q, k")
|
||||
.insert("d_v", "-1", "head dim for v, -1 means equal to d")
|
||||
.insert("scale_s",
|
||||
@@ -106,6 +113,7 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("p_drop", "0", "0~1 probability of dropout")
|
||||
.insert("drop_seed", "1", "seed for random number generator")
|
||||
.insert("drop_offset", "0", "offset for random number generator")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("warmup", "5", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "20", "number of iterations to benchmark the kernel");
|
||||
|
||||
@@ -165,10 +173,20 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
return false;
|
||||
}
|
||||
|
||||
ck_tile::index_t seqlen_q = arg_parser.get_int("s");
|
||||
ck_tile::index_t seqlen_k = arg_parser.get_int("s_k");
|
||||
if(seqlen_k < 0)
|
||||
seqlen_k = seqlen_q;
|
||||
auto [seqlen_qs, seqlen_ks, seqlen_kpads] = decode_seqlen(mode,
|
||||
batch,
|
||||
arg_parser.get_str("s"),
|
||||
arg_parser.get_str("s_k"),
|
||||
arg_parser.get_str("s_kpad"));
|
||||
|
||||
#if 0
|
||||
// clang-format off
|
||||
std::cout << "seqlen_qs:"; for(auto xx : seqlen_qs) { std::cout << xx << ","; } std::cout << std::endl;
|
||||
std::cout << "seqlen_ks:"; for(auto xx : seqlen_ks) { std::cout << xx << ","; } std::cout << std::endl;
|
||||
std::cout << "seqlen_kpads:"; for(auto xx : seqlen_kpads) { std::cout << xx << ","; } std::cout << std::endl;
|
||||
// clang-format on
|
||||
#endif
|
||||
|
||||
ck_tile::index_t hdim_q = arg_parser.get_int("d");
|
||||
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
|
||||
if(hdim_v < 0)
|
||||
@@ -217,7 +235,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
bool lse = arg_parser.get_bool("lse");
|
||||
|
||||
bias_info bias = bias_info::decode(arg_parser.get_str("bias"));
|
||||
mask_info mask = mask_info::decode(arg_parser.get_str("mask"), seqlen_q, seqlen_k);
|
||||
mask_info mask = mask_info::decode(
|
||||
arg_parser.get_str("mask"), seqlen_qs[0], seqlen_ks[0]); // TODO: we don't need x/y anymore
|
||||
|
||||
float p_drop = arg_parser.get_float("p_drop");
|
||||
uint64_t drop_seed = arg_parser.get_uint64("drop_seed");
|
||||
@@ -245,11 +264,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
int stream_repeat = arg_parser.get_int("repeat");
|
||||
bool kname = arg_parser.get_bool("kname");
|
||||
|
||||
ck_tile::stream_config stream_config{
|
||||
nullptr, true, /* log_level = */ (kname ? 1 : 0), stream_warmup, stream_repeat};
|
||||
ck_tile::stream_config stream_config{nullptr,
|
||||
true,
|
||||
/* log_level = */ (kname ? 1 : 0),
|
||||
stream_warmup,
|
||||
stream_repeat,
|
||||
arg_parser.get_str("timer") == std::string("gpu")};
|
||||
|
||||
const auto seqstart_q_host = generate_seqstarts(mode, batch, seqlen_q);
|
||||
const auto seqstart_k_host = generate_seqstarts(mode, batch, seqlen_k);
|
||||
const auto seqstart_q_host = to_seqstarts(seqlen_qs);
|
||||
const auto seqstart_k_host = to_seqstarts(seqlen_ks);
|
||||
const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads);
|
||||
|
||||
using TypeConfig = FmhaFwdTypeConfig<DataType>;
|
||||
|
||||
@@ -312,9 +336,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
// host memory for storing all the tensor elements
|
||||
const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1);
|
||||
const ck_tile::index_t shape_seqlen_q =
|
||||
(mode == mode_enum::batch ? seqlen_q : seqstart_q_host.back());
|
||||
(mode == mode_enum::batch ? seqlen_qs[0] : seqstart_q_host.back());
|
||||
const ck_tile::index_t shape_seqlen_k =
|
||||
(mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back());
|
||||
(mode == mode_enum::batch ? seqlen_ks[0]
|
||||
: (seqlen_kpads[0] < 0 ? seqstart_k_host.back()
|
||||
: seqstart_k_with_padding_host.back()));
|
||||
|
||||
ck_tile::HostTensor<QDataType> q_host(
|
||||
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
|
||||
@@ -421,6 +447,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem seqlen_k_buf(seqlen_kpads[0] < 0 ? 0 : seqlen_ks.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
|
||||
|
||||
@@ -429,7 +456,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
v_buf.ToDevice(v_host.data());
|
||||
bias_buf.ToDevice(bias_host.data());
|
||||
seqstart_q.ToDevice(seqstart_q_host.data());
|
||||
seqstart_k.ToDevice(seqstart_k_host.data());
|
||||
seqstart_k.ToDevice(seqlen_kpads[0] < 0 ? seqstart_k_host.data()
|
||||
: seqstart_k_with_padding_host.data());
|
||||
seqlen_k_buf.ToDevice(seqlen_kpads[0] < 0 ? nullptr : seqlen_ks.data());
|
||||
alibi_slope_buf.ToDevice(alibi_slope_host.data());
|
||||
|
||||
// clang-format off
|
||||
@@ -445,7 +474,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
const std::string prec = arg_parser.get_str("prec");
|
||||
|
||||
std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch
|
||||
<< ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k
|
||||
<< ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_qs[0] << "/" << seqlen_ks[0]
|
||||
<< (seqlen_kpads[0] < 0 ? ""
|
||||
: (std::string("(") + std::to_string(seqlen_kpads[0]) + ")"))
|
||||
<< ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias
|
||||
<< ", p_drop:" << p_drop << ", lse:" << lse << ", squant:" << squant
|
||||
<< ", mask:" << mask << ", v:" << vlayout << std::flush;
|
||||
@@ -476,7 +507,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
return ck_tile::identity{};
|
||||
}();
|
||||
|
||||
auto fmha_args = [&]() {
|
||||
auto fmha_args = [&, k_paddings_ = seqlen_kpads]() {
|
||||
assert(nhead % nhead_k == 0);
|
||||
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
|
||||
/// seqlen_k] in this example, hence both the 'batch_stride_bias' &
|
||||
@@ -526,7 +557,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
o_buf.GetDeviceBuffer(),
|
||||
seqstart_q.GetDeviceBuffer(),
|
||||
seqstart_k.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
k_paddings_[0] < 0 ? nullptr : seqlen_k_buf.GetDeviceBuffer(),
|
||||
shape_seqlen_q,
|
||||
shape_seqlen_k,
|
||||
batch,
|
||||
@@ -607,7 +638,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
// adjust matrix index according to the mode
|
||||
const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0);
|
||||
const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]);
|
||||
const ck_tile::index_t key_offset = (mode == mode_enum::batch ? 0 : seqstart_k_host[wb]);
|
||||
const ck_tile::index_t key_offset =
|
||||
(mode == mode_enum::batch
|
||||
? 0
|
||||
: (seqlen_kpads[0] < 0 ? seqstart_k_host[wb] : seqstart_k_with_padding_host[wb]));
|
||||
|
||||
const auto v_host_ref_lengths =
|
||||
std::array<ck_tile::index_t, 3>{nhead, hdim_v, real_seqlen_k};
|
||||
|
||||
@@ -78,6 +78,11 @@ BOOL_MAP = {
|
||||
"f" : "false"
|
||||
}
|
||||
|
||||
TILE_PARTITIONER_MAP = {
|
||||
"shb" : "ck_tile::FmhaFwdTilePartitioner_SHB",
|
||||
"hbs" : "ck_tile::FmhaFwdTilePartitioner_HBS",
|
||||
}
|
||||
|
||||
GEN_DIR = "" # in Cmake, have to generate files in same folder
|
||||
|
||||
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
|
||||
@@ -138,7 +143,7 @@ using fmha_epilogue_{F_idx} =
|
||||
{F_spad}, {F_dvpad}>>;
|
||||
|
||||
using fmha_kernel_{F_idx} =
|
||||
ck_tile::FmhaFwdKernel<ck_tile::FmhaFwdTilePartitioner<fmha_shape_{F_idx}>,
|
||||
ck_tile::FmhaFwdKernel<{F_tile_partitioner}<fmha_shape_{F_idx}>,
|
||||
fmha_pipeline_{F_idx},
|
||||
fmha_epilogue_{F_idx}>;
|
||||
|
||||
@@ -156,7 +161,7 @@ float fmha_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
|
||||
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel<blocks.x, kBlockPerCu>(s, k_{{}}, grids, blocks, 0, kargs);
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
"""
|
||||
|
||||
@@ -394,6 +399,12 @@ class FmhaFwdKernel:
|
||||
F_pipeline : FmhaFwdPipeline
|
||||
mask_impl : str
|
||||
|
||||
def get_tp(self) -> str:
    """Return the tile-partitioner key for this kernel.

    The key is 'hbs' when ``self.F_mode`` is 'group' and 'shb' for any
    other mode. It is used to index TILE_PARTITIONER_MAP and is also
    embedded in the generated kernel name.
    """
    return 'hbs' if self.F_mode == 'group' else 'shb'
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
kernel_body = str()
|
||||
@@ -418,7 +429,7 @@ class FmhaFwdKernel:
|
||||
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
|
||||
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
|
||||
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
|
||||
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
|
||||
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
|
||||
F_bias = BIAS_MAP[self.F_pipeline.F_bias],
|
||||
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
|
||||
F_dropout = BOOL_MAP[self.F_pipeline.F_dropout],
|
||||
@@ -427,12 +438,13 @@ class FmhaFwdKernel:
|
||||
F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
|
||||
F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
|
||||
F_mode = MODE_MAP[self.F_mode],
|
||||
F_pipeline = PIPELINE_MAP[self.F_pipeline.tag])
|
||||
F_pipeline = PIPELINE_MAP[self.F_pipeline.tag],
|
||||
F_tile_partitioner = TILE_PARTITIONER_MAP[self.get_tp()])
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
# TODO: we don't encode idx here
|
||||
return f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" +\
|
||||
return f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_{self.get_tp()}_" + \
|
||||
self.F_tile.name + '_' + self.F_pipeline.name
|
||||
|
||||
@property
|
||||
|
||||
@@ -29,6 +29,7 @@ $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias
|
||||
$EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
|
||||
|
||||
done
|
||||
done
|
||||
|
||||
@@ -4,12 +4,14 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <cstdlib>
|
||||
#include <optional>
|
||||
#include <ostream>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <functional>
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core/container/span.hpp"
|
||||
|
||||
@@ -37,12 +39,14 @@ std::vector<int32_t> to_seqstarts(ck_tile::span<const int32_t> seqlens)
|
||||
|
||||
std::vector<int32_t> generate_seqlens(mode_enum mode,
|
||||
unsigned count,
|
||||
int32_t seqlens_sum,
|
||||
int32_t seqlen_avg,
|
||||
int32_t seqlen_max = -1, // if not negative, clamp max
|
||||
std::optional<unsigned> seed = std::nullopt)
|
||||
{
|
||||
assert(0 < count);
|
||||
|
||||
std::vector<int32_t> seqlens(count, seqlens_sum);
|
||||
std::vector<int32_t> seqlens(
|
||||
count, seqlen_max > 0 ? (seqlen_avg < seqlen_max ? seqlen_avg : seqlen_max) : seqlen_avg);
|
||||
|
||||
if(mode == mode_enum::group && 1 < count)
|
||||
{
|
||||
@@ -55,7 +59,7 @@ std::vector<int32_t> generate_seqlens(mode_enum mode,
|
||||
std::uniform_int_distribution<size_type> step_dist(1, count - 1);
|
||||
auto next_step = std::bind(step_dist, std::ref(random_engine));
|
||||
|
||||
for(unsigned repeat = seqlens_sum * (count / 2); 0 < repeat; --repeat)
|
||||
for(unsigned repeat = seqlen_avg * (count / 2); 0 < repeat; --repeat)
|
||||
{
|
||||
const size_type to_decrease = next_idx();
|
||||
// make sure each elements of seqlens is always greater than 0
|
||||
@@ -66,6 +70,11 @@ std::vector<int32_t> generate_seqlens(mode_enum mode,
|
||||
|
||||
const size_type to_increase = (to_decrease + next_step()) % count;
|
||||
|
||||
if(seqlen_max > 0 && seqlens[to_increase] >= seqlen_max)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
--seqlens[to_decrease];
|
||||
++seqlens[to_increase];
|
||||
}
|
||||
@@ -76,10 +85,91 @@ std::vector<int32_t> generate_seqlens(mode_enum mode,
|
||||
|
||||
std::vector<int32_t> generate_seqstarts(mode_enum mode,
|
||||
unsigned count,
|
||||
int32_t seqlens_sum,
|
||||
int32_t seqlen_avg,
|
||||
int32_t seqlen_max = -1,
|
||||
std::optional<unsigned> seed = std::nullopt)
|
||||
{
|
||||
return to_seqstarts(generate_seqlens(mode, count, seqlens_sum, seed));
|
||||
return to_seqstarts(generate_seqlens(mode, count, seqlen_avg, seqlen_max, seed));
|
||||
}
|
||||
|
||||
/*
|
||||
* decode the seqlen string from cmdline
|
||||
* example (assume batch=3)
|
||||
* q_val=1,2,3 k_val=4,5,6 -> OK
|
||||
* q_val=1,2,3 -> OK, k same as q
|
||||
* q_val=1,2 -> OK, q will rand remaining 1 element, k same as q
|
||||
* q_val=1,2 k_val=4,5 -> OK, q/k will rand remaining 1 element
|
||||
 *   q_val=1,2,3,4          -> OK, but the values beyond batch are ignored
|
||||
*
|
||||
* q_val=1,2 k_val=4,5,6 -> not OK, k must have same splits with q
|
||||
* q_val=1,2 k_val=4 -> not OK, k must have same splits with q
|
||||
*/
|
||||
std::tuple<std::vector<ck_tile::index_t>,
|
||||
std::vector<ck_tile::index_t>,
|
||||
std::vector<ck_tile::index_t>>
|
||||
decode_seqlen(mode_enum mode,
|
||||
ck_tile::index_t batch,
|
||||
std::string q_val,
|
||||
std::string k_val,
|
||||
std::string k_pad_val,
|
||||
std::optional<unsigned> seed = std::nullopt)
|
||||
{
|
||||
#define _S2I_(str_) static_cast<ck_tile::index_t>(std::atoi((str_).c_str()))
|
||||
if(mode == mode_enum::batch)
|
||||
{
|
||||
ck_tile::index_t q = _S2I_(q_val);
|
||||
ck_tile::index_t k = _S2I_(k_val);
|
||||
auto s_q = std::vector<ck_tile::index_t>(batch, q);
|
||||
auto s_k = std::vector<ck_tile::index_t>(batch, k < 0 ? q : k);
|
||||
auto s_kpad = std::vector<ck_tile::index_t>(batch, -1); // TODO: batch not support k_padding
|
||||
return std::make_tuple(s_q, s_k, s_kpad);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::index_t idx = 0;
|
||||
std::string::size_type pos_q = 0;
|
||||
std::string::size_type pos_k = 0;
|
||||
std::string::size_type pos_kp = 0;
|
||||
std::vector<ck_tile::index_t> s_q;
|
||||
std::vector<ck_tile::index_t> s_k;
|
||||
std::vector<ck_tile::index_t> s_kpad;
|
||||
while(true)
|
||||
{
|
||||
auto found_q = q_val.find(',', pos_q);
|
||||
auto found_k = k_val.find(',', pos_k);
|
||||
auto found_kp = k_pad_val.find(',', pos_kp);
|
||||
|
||||
ck_tile::index_t q = _S2I_(
|
||||
q_val.substr(pos_q, found_q == std::string::npos ? found_q : found_q - pos_q));
|
||||
ck_tile::index_t k = _S2I_(
|
||||
k_val.substr(pos_k, found_k == std::string::npos ? found_k : found_k - pos_k));
|
||||
ck_tile::index_t kp = _S2I_(k_pad_val.substr(
|
||||
pos_kp, found_kp == std::string::npos ? found_kp : found_kp - pos_kp));
|
||||
|
||||
s_q.push_back(q);
|
||||
s_k.push_back(k < 0 ? q : k);
|
||||
s_kpad.push_back(kp);
|
||||
idx++;
|
||||
if(found_q == std::string::npos || idx >= batch)
|
||||
{
|
||||
break;
|
||||
}
|
||||
pos_q = found_q + 1;
|
||||
pos_k = found_k == std::string::npos ? pos_k : found_k + 1;
|
||||
pos_kp = found_kp == std::string::npos ? pos_kp : found_kp + 1;
|
||||
}
|
||||
if(idx < batch)
|
||||
{
|
||||
auto rem_q = generate_seqlens(mode, batch - idx, s_q.back(), s_kpad.back(), seed);
|
||||
auto rem_k = generate_seqlens(mode, batch - idx, s_k.back(), s_kpad.back(), seed);
|
||||
|
||||
s_q.insert(s_q.end(), rem_q.begin(), rem_q.end());
|
||||
s_k.insert(s_k.end(), rem_k.begin(), rem_k.end());
|
||||
s_kpad.insert(s_kpad.end(), batch - idx, s_kpad.back());
|
||||
}
|
||||
return std::make_tuple(s_q, s_k, s_kpad);
|
||||
}
|
||||
#undef _S2I_
|
||||
}
|
||||
|
||||
int env_get_int(const char* var_name, int default_int)
|
||||
@@ -87,6 +177,6 @@ int env_get_int(const char* var_name, int default_int)
|
||||
char* v = getenv(var_name);
|
||||
int r = default_int;
|
||||
if(v)
|
||||
r = atoi(v);
|
||||
r = std::atoi(v);
|
||||
return r;
|
||||
}
|
||||
|
||||
@@ -104,20 +104,25 @@ inline void flush_icache()
|
||||
hip_check_error(hipGetLastError());
|
||||
}
|
||||
// if TimePreprocess == false, the returned time does not include the preprocess time
|
||||
template <bool TimePreprocess, typename Args, typename F, typename PreProcessFunc>
|
||||
template <bool TimePreprocess,
|
||||
typename GemmArgs,
|
||||
typename... Args,
|
||||
typename F,
|
||||
typename PreProcessFunc>
|
||||
float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
|
||||
PreProcessFunc preprocess,
|
||||
F kernel,
|
||||
dim3 grid_dim,
|
||||
dim3 block_dim,
|
||||
std::size_t lds_byte,
|
||||
Args& args)
|
||||
GemmArgs& gemm_args,
|
||||
Args... args)
|
||||
{
|
||||
#if CK_TIME_KERNEL
|
||||
#define MEDIAN 1
|
||||
if(stream_config.time_kernel_)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
printf("%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n",
|
||||
__func__,
|
||||
@@ -133,7 +138,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
|
||||
// warm up
|
||||
for(int i = 0; i < stream_config.cold_niters_; ++i)
|
||||
{
|
||||
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args);
|
||||
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
|
||||
hip_check_error(hipGetLastError());
|
||||
}
|
||||
|
||||
@@ -142,7 +147,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
|
||||
{
|
||||
return 0.0;
|
||||
}
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
printf("Start running %d times...\n", nrepeat);
|
||||
}
|
||||
@@ -172,7 +177,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
|
||||
preprocess();
|
||||
}
|
||||
// run real kernel
|
||||
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args);
|
||||
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
|
||||
hip_check_error(hipGetLastError());
|
||||
// end real kernel
|
||||
|
||||
@@ -186,13 +191,13 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
|
||||
total_time += cur_time;
|
||||
#endif
|
||||
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "i: " << i << " cur_time: " << cur_time << std::endl;
|
||||
|
||||
printf("args.p_a_grid: %p, args.p_b_grid:%p\n",
|
||||
static_cast<const void*>(args.p_a_grid),
|
||||
static_cast<const void*>(args.p_b_grid));
|
||||
printf("gemm_args.p_a_grid: %p, gemm_args.p_b_grid:%p\n",
|
||||
static_cast<const void*>(gemm_args.p_a_grid),
|
||||
static_cast<const void*>(gemm_args.p_b_grid));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -216,13 +221,13 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
|
||||
else
|
||||
{
|
||||
preprocess();
|
||||
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args);
|
||||
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
|
||||
hip_check_error(hipGetLastError());
|
||||
|
||||
return 0;
|
||||
}
|
||||
#else
|
||||
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args);
|
||||
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
|
||||
hip_check_error(hipGetLastError());
|
||||
|
||||
return 0;
|
||||
|
||||
@@ -20,7 +20,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
|
||||
#if CK_TIME_KERNEL
|
||||
if(stream_config.time_kernel_)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
printf("%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n",
|
||||
__func__,
|
||||
@@ -41,7 +41,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
|
||||
}
|
||||
|
||||
const int nrepeat = stream_config.nrepeat_;
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
printf("Start running %d times...\n", nrepeat);
|
||||
}
|
||||
@@ -95,7 +95,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
|
||||
#if CK_TIME_KERNEL
|
||||
if(stream_config.time_kernel_)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
printf("%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n",
|
||||
__func__,
|
||||
@@ -117,7 +117,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
|
||||
}
|
||||
|
||||
const int nrepeat = stream_config.nrepeat_;
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
printf("Start running %d times...\n", nrepeat);
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -1952,7 +1952,7 @@ struct Modulo
|
||||
}
|
||||
};
|
||||
|
||||
template <typename LowLengths>
|
||||
template <typename LowLengths, bool ApplyModulo>
|
||||
struct Xor
|
||||
{
|
||||
using LowerIndex = MultiIndex<2>;
|
||||
@@ -1981,8 +1981,15 @@ struct Xor
|
||||
|
||||
idx_low(Number<0>{}) = idx_up[Number<0>{}];
|
||||
|
||||
idx_low(Number<1>{}) =
|
||||
idx_up[Number<1>{}] ^ (idx_up[Number<0>{}] % up_lengths_[Number<1>{}]);
|
||||
if constexpr(ApplyModulo)
|
||||
{
|
||||
idx_low(Number<1>{}) =
|
||||
idx_up[Number<1>{}] ^ (idx_up[Number<0>{}] % up_lengths_[Number<1>{}]);
|
||||
}
|
||||
else
|
||||
{
|
||||
idx_low(Number<1>{}) = idx_up[Number<1>{}] ^ idx_up[Number<0>{}];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename LowIdxDiff,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -128,9 +128,15 @@ __host__ __device__ constexpr auto make_modulo_transform(const Modulus& modulus,
|
||||
return Modulo<Modulus, UpLength>{modulus, up_length};
|
||||
}
|
||||
|
||||
template <typename LowLengths>
|
||||
__host__ __device__ constexpr auto make_xor_with_modulo_transform(const LowLengths& low_lengths)
|
||||
{
|
||||
return Xor<LowLengths, true /*ApplyModulo*/>{low_lengths};
|
||||
}
|
||||
|
||||
template <typename LowLengths>
|
||||
__host__ __device__ constexpr auto make_xor_transform(const LowLengths& low_lengths)
|
||||
{
|
||||
return Xor<LowLengths>{low_lengths};
|
||||
return Xor<LowLengths, false /*ApplyModulo*/>{low_lengths};
|
||||
}
|
||||
} // namespace ck
|
||||
|
||||
@@ -795,11 +795,6 @@ struct BlockwiseGemmXdlops_v2
|
||||
"wrong!");
|
||||
}
|
||||
|
||||
__host__ __device__ BlockwiseGemmXdlops_v2(const BlockwiseGemmXdlops_v2& other)
|
||||
: a_thread_copy_(other.a_origin), b_thread_copy_(other.b_origin)
|
||||
{
|
||||
}
|
||||
|
||||
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
|
||||
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
|
||||
{
|
||||
|
||||
@@ -587,7 +587,7 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
|
||||
BatchStrideD1s,
|
||||
BatchStrideE1}
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "a0_grid_desc_m_k_{" << a0_grid_desc_m_k_.GetLength(I0) << ", "
|
||||
<< a0_grid_desc_m_k_.GetLength(I1) << "}" << std::endl;
|
||||
|
||||
@@ -658,7 +658,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
{
|
||||
std::cout << "arg.Batch_ = " << arg.Batch_ << std::endl;
|
||||
|
||||
@@ -719,7 +719,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
arg.Print();
|
||||
}
|
||||
|
||||
@@ -53,8 +53,7 @@ __global__ void
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2ETileMap block_2_etile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx94__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
|
||||
@@ -516,7 +516,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
float ave_time = 0;
|
||||
for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_k0_m_k1_container_{"
|
||||
|
||||
@@ -644,7 +644,7 @@ struct
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << DeviceOp{}.GetTypeString() << std::endl;
|
||||
std::cout << "N " << arg.Conv_N_ << ", "
|
||||
|
||||
@@ -614,7 +614,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << DeviceOp{}.GetTypeString() << std::endl;
|
||||
std::cout << "N " << arg.Conv_N_ << ", "
|
||||
|
||||
@@ -579,7 +579,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << DeviceOp{}.GetTypeString() << std::endl;
|
||||
std::cout << "N " << arg.Conv_N_ << ", "
|
||||
|
||||
@@ -431,7 +431,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
|
||||
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
|
||||
|
||||
@@ -401,7 +401,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "num_batches_of_GEMM = " << arg.num_subbatches_ << std::endl;
|
||||
std::cout << "a_grid_desc_k0_m_k1{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
|
||||
|
||||
@@ -1272,7 +1272,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
|
||||
float ave_time = 0;
|
||||
for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_k0_m_k1_container_{"
|
||||
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", "
|
||||
|
||||
@@ -1220,7 +1220,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
|
||||
float ave_time = 0;
|
||||
for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_k0_m_k1{"
|
||||
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", "
|
||||
|
||||
@@ -334,7 +334,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_k0_m0_m1_k1_{"
|
||||
<< arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
|
||||
|
||||
@@ -510,7 +510,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
|
||||
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
|
||||
|
||||
@@ -514,7 +514,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
|
||||
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
|
||||
|
||||
@@ -299,7 +299,7 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm<ALayout,
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
|
||||
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -45,8 +45,7 @@ __global__ void
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CDEElementwiseOperation cde_element_op)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
const index_t KBatch = 1;
|
||||
|
||||
@@ -553,7 +553,7 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
|
||||
|
||||
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "group: " << i << " arg.a_grid_desc_k0_m_k1_{"
|
||||
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I0)
|
||||
|
||||
@@ -337,6 +337,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
|
||||
elementwise_d_grid_descs_m_n_.reserve(group_count_);
|
||||
ds_grid_pointer_.reserve(group_count_);
|
||||
group_grid_size_.reserve(group_count_);
|
||||
e_ptrs_.reserve(group_count_);
|
||||
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); ++i)
|
||||
{
|
||||
@@ -380,7 +381,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
|
||||
const index_t block_end = grid_size_ + grid_size_grp;
|
||||
|
||||
grid_size_ += grid_size_grp;
|
||||
group_grid_size_[i] = grid_size_grp;
|
||||
group_grid_size_.push_back(grid_size_grp);
|
||||
// block-to-e-tile map
|
||||
auto grouped_block_2_ctile_map =
|
||||
GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
|
||||
@@ -421,9 +422,9 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
|
||||
elementwise_c_grid_descs_m_n_.push_back(c_grid_desc_m_n);
|
||||
elementwise_d_grid_descs_m_n_.push_back(ds_grid_desc_m_n);
|
||||
ds_grid_pointer_.push_back(p_ds_grid);
|
||||
// Store a copy of E pointers for elementwise kernel destination
|
||||
e_ptrs_.push_back(p_Es[i]);
|
||||
}
|
||||
// Store a copy of E pointers for elementwise kernel destination
|
||||
e_ptrs_ = p_Es;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -467,7 +468,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
|
||||
gemm_kernel_args_[i].block_start_ = block_start;
|
||||
gemm_kernel_args_[i].block_end_ = block_end;
|
||||
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
index_t tiles = (block_end - block_start) / K_BATCH;
|
||||
std::cout << "block_start: " << block_start << "\n"
|
||||
@@ -494,7 +495,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
|
||||
arg.karg_.p_c_grid = p_workspace + offset;
|
||||
index_t tiles = (arg.block_end_ - arg.block_start_) / arg.karg_.k_batch;
|
||||
offset += tiles * MPerBlock * NPerBlock;
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "block_start: " << arg.block_start_ << "\n"
|
||||
<< "block_end: " << arg.block_end_ << "\n"
|
||||
@@ -774,13 +775,13 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
cast_pointer_to_constant_address_space(dev_gemm_args),
|
||||
arg.group_count_,
|
||||
arg.gemm_kernel_args_.size(),
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
PassThrough{});
|
||||
|
||||
// Elementwise kernels
|
||||
for(int i = 0; i < arg.group_count_; ++i)
|
||||
for(size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
|
||||
{
|
||||
time += launch_and_time_kernel(
|
||||
stream_config,
|
||||
@@ -818,7 +819,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
|
||||
if((ck::type_convert<ck::index_t>(arg.gemm_kernel_args_.size()) +
|
||||
arg.skipped_group_count_) != arg.group_count_)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "The group count is not equal to sum of skipped groups "
|
||||
"and kernel args size!"
|
||||
@@ -835,7 +836,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
|
||||
bool group_arg_valid = GridwiseGemm::CheckValidity(gemm_arg);
|
||||
if(not group_arg_valid)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "[" << __func__ << "] group id: " << i
|
||||
<< " has invalid GridwiseGemm settings!" << std::endl;
|
||||
|
||||
@@ -620,7 +620,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
|
||||
GridwiseGemm::template CheckTensorTransfersValidity<ALayout, BLayout, ELayout>(
|
||||
M, N, K)))
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "The provided GEMM problem size (M,N,K) [" << M << "," << N << ","
|
||||
<< K << "] are not supported by current template parameters!"
|
||||
|
||||
@@ -514,7 +514,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
|
||||
|
||||
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "group: " << i << " arg.a_grid_desc_ak0_m_ak1_{"
|
||||
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0)
|
||||
|
||||
@@ -529,7 +529,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
|
||||
if((ck::type_convert<ck::index_t>(arg.gemm_kernel_args_.size()) +
|
||||
arg.skipped_group_count_) != arg.group_count_)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "The group count is not equal to sum of skipped groups "
|
||||
"and kernel args size!"
|
||||
@@ -545,7 +545,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
|
||||
bool group_arg_valid = GridwiseGemm::CheckValidity(a);
|
||||
if(not group_arg_valid)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "[" << __func__ << "] group id: " << i
|
||||
<< " has invalid GridwiseGemm settings!" << std::endl;
|
||||
|
||||
@@ -50,8 +50,7 @@ __global__ void
|
||||
const CElementwiseOperation c_element_op,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
|
||||
defined(__gfx1102__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
|
||||
__shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
@@ -80,7 +79,7 @@ __global__ void
|
||||
ignore = b_element_op;
|
||||
ignore = c_element_op;
|
||||
ignore = block_2_ctile_map;
|
||||
#endif // end of if (defined(__gfx1100__))
|
||||
#endif // end of if (defined(__gfx11__))
|
||||
}
|
||||
|
||||
// Assume B is Col-Major
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -34,8 +34,7 @@ __global__ void
|
||||
// __attribute__((amdgpu_waves_per_eu(1, 1)))
|
||||
kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
|
||||
@@ -48,7 +47,7 @@ __global__ void
|
||||
karg);
|
||||
#else
|
||||
ignore = karg;
|
||||
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
|
||||
#endif // end of if (defined(__gfx9__))
|
||||
}
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
@@ -63,8 +62,7 @@ __global__ void
|
||||
// __attribute__((amdgpu_waves_per_eu(1, 1)))
|
||||
kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
// Passing two LDS pointers is the key to telling the compiler that ds_read/write
|
||||
// can operate on different LDS chunks at the same time without an ordering dependency
|
||||
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
@@ -81,7 +79,7 @@ __global__ void
|
||||
karg);
|
||||
#else
|
||||
ignore = karg;
|
||||
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
|
||||
#endif // end of if (defined(__gfx9__))
|
||||
}
|
||||
|
||||
template <typename ALayout,
|
||||
@@ -605,8 +603,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
|
||||
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
a_lds_block_desc,
|
||||
make_tuple(make_xor_transform(make_tuple(Number<MPerBlock / MLdsLayer>{},
|
||||
Number<AK0Number * MLdsLayer>{})),
|
||||
make_tuple(make_xor_with_modulo_transform(make_tuple(
|
||||
Number<MPerBlock / MLdsLayer>{}, Number<AK0Number * MLdsLayer>{})),
|
||||
make_pass_through_transform(AK1Number)),
|
||||
make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<1, 0>{}, Sequence<2>{}));
|
||||
@@ -671,7 +669,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
make_tuple(
|
||||
make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
|
||||
make_pass_through_transform(Number<K0PerThreadWrite>{}),
|
||||
make_xor_transform(
|
||||
make_xor_with_modulo_transform(
|
||||
make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
|
||||
make_pass_through_transform(Number<mpair>{}),
|
||||
make_pass_through_transform(AK1Number)),
|
||||
@@ -742,8 +740,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
|
||||
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
b_lds_block_desc,
|
||||
make_tuple(make_xor_transform(make_tuple(Number<NPerBlock / NLdsLayer>{},
|
||||
Number<BK0Number * NLdsLayer>{})),
|
||||
make_tuple(make_xor_with_modulo_transform(make_tuple(
|
||||
Number<NPerBlock / NLdsLayer>{}, Number<BK0Number * NLdsLayer>{})),
|
||||
make_pass_through_transform(BK1Number)),
|
||||
make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<1, 0>{}, Sequence<2>{}));
|
||||
@@ -805,7 +803,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
make_tuple(
|
||||
make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
|
||||
make_pass_through_transform(Number<K0PerThreadWrite>{}),
|
||||
make_xor_transform(
|
||||
make_xor_with_modulo_transform(
|
||||
make_tuple(Number<KThreadReadPerm * N1>{}, Number<kfold * N0 / npair>{})),
|
||||
make_pass_through_transform(Number<npair>{}),
|
||||
make_pass_through_transform(BK1Number)),
|
||||
@@ -935,7 +933,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
{
|
||||
if(!(karg.M % MPerBlock == 0))
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
|
||||
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
|
||||
@@ -952,7 +950,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
{
|
||||
if(!(karg.N % NPerBlock == 0))
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
|
||||
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
|
||||
@@ -971,7 +969,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
auto K_t = karg.KBatch * KPerBlock;
|
||||
if(!(karg.K % K_t == 0))
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
|
||||
<< karg.K << " " << __FILE__ << ":" << __LINE__
|
||||
@@ -995,7 +993,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
{
|
||||
if(karg.K % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg K (" << karg.K
|
||||
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
|
||||
@@ -1009,7 +1007,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
{
|
||||
if(karg.M % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg M (" << karg.M
|
||||
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
|
||||
@@ -1024,7 +1022,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
{
|
||||
if(karg.N % BBlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg N (" << karg.N
|
||||
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
|
||||
@@ -1038,7 +1036,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
{
|
||||
if(karg.K % BBlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg K (" << karg.K
|
||||
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
|
||||
@@ -1053,7 +1051,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
{
|
||||
if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg N (" << karg.N
|
||||
<< ") value is not a multiple of "
|
||||
@@ -1069,7 +1067,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
{
|
||||
if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg M (" << karg.M
|
||||
<< ") value is not a multiple of "
|
||||
@@ -1084,7 +1082,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
|
||||
if constexpr(is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << " KBatch: " << karg.KBatch << " > 1 is not support yet" << __FILE__
|
||||
<< ":" << __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -33,8 +33,7 @@ __global__ void
|
||||
// __attribute__((amdgpu_waves_per_eu(1, 1)))
|
||||
kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
@@ -49,7 +48,7 @@ __global__ void
|
||||
karg.c_element_op);
|
||||
#else
|
||||
ignore = karg;
|
||||
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
|
||||
#endif // end of if (defined(__gfx9__))
|
||||
}
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
@@ -64,8 +63,7 @@ __global__ void
|
||||
// __attribute__((amdgpu_waves_per_eu(1, 1)))
|
||||
kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
// Passing two LDS pointers is the key to telling the compiler that ds_read/write
|
||||
// can operate on different LDS chunks at the same time without an ordering dependency
|
||||
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
@@ -84,7 +82,7 @@ __global__ void
|
||||
karg.c_element_op);
|
||||
#else
|
||||
ignore = karg;
|
||||
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
|
||||
#endif // end of if (defined(__gfx9__))
|
||||
}
|
||||
|
||||
template <typename ALayout,
|
||||
@@ -783,8 +781,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
|
||||
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
a_lds_block_desc,
|
||||
make_tuple(make_xor_transform(make_tuple(Number<MPerBlock / MLdsLayer>{},
|
||||
Number<AK0Number * MLdsLayer>{})),
|
||||
make_tuple(make_xor_with_modulo_transform(make_tuple(
|
||||
Number<MPerBlock / MLdsLayer>{}, Number<AK0Number * MLdsLayer>{})),
|
||||
make_pass_through_transform(AK1Number)),
|
||||
make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<1, 0>{}, Sequence<2>{}));
|
||||
@@ -849,7 +847,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
make_tuple(
|
||||
make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
|
||||
make_pass_through_transform(Number<K0PerThreadWrite>{}),
|
||||
make_xor_transform(
|
||||
make_xor_with_modulo_transform(
|
||||
make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
|
||||
make_pass_through_transform(Number<mpair>{}),
|
||||
make_pass_through_transform(AK1Number)),
|
||||
@@ -920,8 +918,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
|
||||
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
b_lds_block_desc,
|
||||
make_tuple(make_xor_transform(make_tuple(Number<NPerBlock / NLdsLayer>{},
|
||||
Number<BK0Number * NLdsLayer>{})),
|
||||
make_tuple(make_xor_with_modulo_transform(make_tuple(
|
||||
Number<NPerBlock / NLdsLayer>{}, Number<BK0Number * NLdsLayer>{})),
|
||||
make_pass_through_transform(BK1Number)),
|
||||
make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<1, 0>{}, Sequence<2>{}));
|
||||
@@ -983,7 +981,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
make_tuple(
|
||||
make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
|
||||
make_pass_through_transform(Number<K0PerThreadWrite>{}),
|
||||
make_xor_transform(
|
||||
make_xor_with_modulo_transform(
|
||||
make_tuple(Number<KThreadReadPerm * N1>{}, Number<kfold * N0 / npair>{})),
|
||||
make_pass_through_transform(Number<npair>{}),
|
||||
make_pass_through_transform(BK1Number)),
|
||||
@@ -1113,7 +1111,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
{
|
||||
if(!(karg.M % MPerBlock == 0))
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
|
||||
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
|
||||
@@ -1130,7 +1128,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
{
|
||||
if(!(karg.N % NPerBlock == 0))
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
|
||||
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
|
||||
@@ -1149,7 +1147,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
auto K_t = karg.KBatch * KPerBlock;
|
||||
if(!(karg.K % K_t == 0))
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
|
||||
<< karg.K << " " << __FILE__ << ":" << __LINE__
|
||||
@@ -1173,7 +1171,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
{
|
||||
if(karg.K % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg K (" << karg.K
|
||||
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
|
||||
@@ -1187,7 +1185,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
{
|
||||
if(karg.M % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg M (" << karg.M
|
||||
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
|
||||
@@ -1202,7 +1200,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
{
|
||||
if(karg.N % BBlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg N (" << karg.N
|
||||
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
|
||||
@@ -1216,7 +1214,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
{
|
||||
if(karg.K % BBlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg K (" << karg.K
|
||||
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
|
||||
@@ -1231,7 +1229,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
{
|
||||
if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg N (" << karg.N
|
||||
<< ") value is not a multiple of "
|
||||
@@ -1247,7 +1245,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
{
|
||||
if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg M (" << karg.M
|
||||
<< ") value is not a multiple of "
|
||||
|
||||
@@ -38,8 +38,7 @@ __global__ void
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CElementwiseOperation c_element_op)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
|
||||
|
||||
__shared__ uint8_t p_shared[shared_size];
|
||||
@@ -52,7 +51,7 @@ __global__ void
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
ignore = c_element_op;
|
||||
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
|
||||
#endif // end of if (defined(__gfx9__))
|
||||
}
|
||||
|
||||
template <index_t BlockSize,
|
||||
|
||||
@@ -446,7 +446,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
{
|
||||
if(!(karg.M % MPerBlock == 0))
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
|
||||
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
|
||||
@@ -463,7 +463,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
{
|
||||
if(!(karg.N % NPerBlock == 0))
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
|
||||
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
|
||||
@@ -482,7 +482,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
auto K_t = karg.k_batch * K0PerBlock * K1;
|
||||
if(!(karg.K % K_t == 0))
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
|
||||
<< karg.K << " " << __FILE__ << ":" << __LINE__
|
||||
@@ -496,7 +496,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
{
|
||||
if(karg.K % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg K (" << karg.K
|
||||
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
|
||||
@@ -510,7 +510,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
{
|
||||
if(karg.M % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg M (" << karg.M
|
||||
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
|
||||
@@ -525,7 +525,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
{
|
||||
if(karg.N % BBlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg N (" << karg.N
|
||||
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
|
||||
@@ -539,7 +539,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
{
|
||||
if(karg.K % BBlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg K (" << karg.K
|
||||
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
|
||||
@@ -554,7 +554,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
{
|
||||
if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg N (" << karg.N
|
||||
<< ") value is not a multiple of "
|
||||
@@ -569,7 +569,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
{
|
||||
if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg M (" << karg.M
|
||||
<< ") value is not a multiple of "
|
||||
@@ -584,7 +584,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
|
||||
const auto num_k_loop = karg.K0Padded / K0PerBlock;
|
||||
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
|
||||
{
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "The number of k loops (" << num_k_loop
|
||||
<< ") value is not supported by GridwiseGemm Pipeline."
|
||||
|
||||
@@ -0,0 +1,640 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/library/utility/numeric.hpp"
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
|
||||
/**
 * @brief Transform conv bwd weight to gemm v2
 *
 * This version does the following things:
 * 1. Merges KBatch with K0 to align the descriptor with universal gemm.
 * 2. Merges Batch with the M and N dimensions. This increases the amount of
 *    compute when M and N are small. It also allows vectorized loads and
 *    stores in the case of K = 1, C = 1 and NHWGC layout.
 */
|
||||
template <index_t NDimSpatial,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t GemmK1Number,
|
||||
index_t K0PerBlock,
|
||||
index_t NumBatchToMerge,
|
||||
device::ConvolutionBackwardWeightSpecialization ConvBackwardWeightSpecialization>
|
||||
struct TransformConvBwdWeightToGemmV2
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
template <index_t NDim, typename enable_if<NDim == 2, bool>::type = false>
|
||||
constexpr static auto
|
||||
make_out_grid_desc(const index_t N,
|
||||
const index_t Ho,
|
||||
const index_t Wo,
|
||||
const index_t K,
|
||||
const std::array<index_t, NDimSpatial + 3>& output_strides)
|
||||
{
|
||||
const index_t BatchStride = output_strides[0];
|
||||
const index_t WoStride = output_strides[4];
|
||||
const auto KStride = Number<1>{};
|
||||
return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, NumBatchToMerge, K),
|
||||
make_tuple(WoStride, BatchStride, KStride));
|
||||
}
|
||||
|
||||
template <index_t NDim, typename enable_if<NDim == 2, bool>::type = false>
|
||||
constexpr static auto
|
||||
make_in_grid_desc(const index_t N,
|
||||
const index_t Hi,
|
||||
const index_t Wi,
|
||||
const index_t C,
|
||||
const std::array<index_t, NDimSpatial + 3>& input_strides)
|
||||
{
|
||||
const index_t BatchStride = input_strides[0];
|
||||
const index_t NStride = input_strides[1];
|
||||
const index_t HiStride = input_strides[3];
|
||||
const index_t WiStride = input_strides[4];
|
||||
const auto CStride = input_strides[2];
|
||||
if constexpr(ConvBackwardWeightSpecialization ==
|
||||
device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(N * Hi * Wi, NumBatchToMerge, C),
|
||||
make_tuple(WiStride, BatchStride, CStride));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N, Hi, Wi, NumBatchToMerge, C),
|
||||
make_tuple(NStride, HiStride, WiStride, BatchStride, CStride));
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t NDim, typename enable_if<NDim == 2, bool>::type = false>
|
||||
constexpr static auto
|
||||
make_wei_grid_desc(const index_t K,
|
||||
const index_t Y,
|
||||
const index_t X,
|
||||
const index_t C,
|
||||
const std::array<index_t, NDimSpatial + 3>& weights_strides)
|
||||
{
|
||||
const auto CStride = Number<1>{};
|
||||
const auto KStride = weights_strides[1];
|
||||
const auto XStride = weights_strides[4];
|
||||
const auto BatchStride = weights_strides[0];
|
||||
// Add NumBatchToMerge for Batch+M dimension and, 1 as a placehorder
|
||||
// for Batch+N dimension
|
||||
const auto desc = make_naive_tensor_descriptor(
|
||||
make_tuple(NumBatchToMerge, K, Y * X, 1, C),
|
||||
make_tuple(BatchStride, KStride, XStride, BatchStride, CStride));
|
||||
// Padd 1 to NumBatchToMerge
|
||||
const auto padded_desc = transform_tensor_descriptor(
|
||||
desc,
|
||||
make_tuple(make_pass_through_transform(NumBatchToMerge),
|
||||
make_pass_through_transform(K),
|
||||
make_pass_through_transform(Y * X),
|
||||
make_pad_transform(1, 0, NumBatchToMerge - 1),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
|
||||
// We need only matrices from diagonal. Xor returns 0 for the same
|
||||
// values. So if matrices is not on diagonal then it will be stored in padding.
|
||||
// To avoid use of modulo after xor we assume that NumBatch to merge is power of 2.
|
||||
static_assert(NumBatchToMerge == 1 || NumBatchToMerge == 2 || NumBatchToMerge == 4 ||
|
||||
NumBatchToMerge == 8 || NumBatchToMerge == 16 || NumBatchToMerge == 32 ||
|
||||
NumBatchToMerge == 64);
|
||||
const auto unmerged_padded_desc = transform_tensor_descriptor(
|
||||
padded_desc,
|
||||
make_tuple(make_xor_transform(make_tuple(NumBatchToMerge, NumBatchToMerge)),
|
||||
make_pass_through_transform(K),
|
||||
make_pass_through_transform(Y * X),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}),
|
||||
make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}));
|
||||
// Merge To M, N
|
||||
return transform_tensor_descriptor(
|
||||
unmerged_padded_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(NumBatchToMerge, K)),
|
||||
make_merge_transform(make_tuple(Y * X, NumBatchToMerge, C))),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2, 3, 4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
|
||||
template <index_t NDim, typename enable_if<NDim == 3, bool>::type = false>
|
||||
constexpr static auto
|
||||
make_out_grid_desc(const index_t N,
|
||||
const index_t Do,
|
||||
const index_t Ho,
|
||||
const index_t Wo,
|
||||
const index_t K,
|
||||
const std::array<index_t, NDimSpatial + 3>& output_strides)
|
||||
{
|
||||
const index_t BatchStride = output_strides[0];
|
||||
const index_t WoStride = output_strides[5];
|
||||
const auto KStride = Number<1>{};
|
||||
return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, NumBatchToMerge, K),
|
||||
make_tuple(WoStride, BatchStride, KStride));
|
||||
}
|
||||
|
||||
template <index_t NDim, typename enable_if<NDim == 3, bool>::type = false>
|
||||
constexpr static auto
|
||||
make_in_grid_desc(const index_t N,
|
||||
const index_t Di,
|
||||
const index_t Hi,
|
||||
const index_t Wi,
|
||||
const index_t C,
|
||||
const std::array<index_t, NDimSpatial + 3>& input_strides)
|
||||
{
|
||||
const index_t BatchStride = input_strides[0];
|
||||
const index_t NStride = input_strides[1];
|
||||
const index_t DiStride = input_strides[3];
|
||||
const index_t HiStride = input_strides[4];
|
||||
const index_t WiStride = input_strides[5];
|
||||
const auto CStride = input_strides[2];
|
||||
if constexpr(ConvBackwardWeightSpecialization ==
|
||||
device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(N * Di * Hi * Wi, NumBatchToMerge, C),
|
||||
make_tuple(WiStride, BatchStride, CStride));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N, Di, Hi, Wi, NumBatchToMerge, C),
|
||||
make_tuple(NStride, DiStride, HiStride, WiStride, BatchStride, CStride));
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t NDim, typename enable_if<NDim == 3, bool>::type = false>
|
||||
constexpr static auto
|
||||
make_wei_grid_desc(const index_t K,
|
||||
const index_t Z,
|
||||
const index_t Y,
|
||||
const index_t X,
|
||||
const index_t C,
|
||||
const std::array<index_t, NDimSpatial + 3>& weights_strides)
|
||||
{
|
||||
const auto CStride = Number<1>{};
|
||||
const auto KStride = weights_strides[1];
|
||||
const auto XStride = weights_strides[5];
|
||||
const auto BatchStride = weights_strides[0];
|
||||
// Add NumBatchToMerge for Batch+M dimension and, 1 for placehord for Batch+N dimension
|
||||
const auto desc = make_naive_tensor_descriptor(
|
||||
make_tuple(NumBatchToMerge, K, Z * Y * X, 1, C),
|
||||
make_tuple(BatchStride, KStride, XStride, BatchStride, CStride));
|
||||
// Padd 1 to NumBatchToMerge
|
||||
const auto padded_desc = transform_tensor_descriptor(
|
||||
desc,
|
||||
make_tuple(make_pass_through_transform(NumBatchToMerge),
|
||||
make_pass_through_transform(K),
|
||||
make_pass_through_transform(Z * Y * X),
|
||||
make_pad_transform(1, 0, NumBatchToMerge - 1),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
|
||||
// We need only matrices from diagonal. Xor returns 0 for the same
|
||||
// values. So if matrices is not on diagonal then it will be stored in padding.
|
||||
// To avoid use of modulo after xor we assume that NumBatch to merge is power of 2.
|
||||
static_assert(NumBatchToMerge == 1 || NumBatchToMerge == 2 || NumBatchToMerge == 4 ||
|
||||
NumBatchToMerge == 8 || NumBatchToMerge == 16 || NumBatchToMerge == 32 ||
|
||||
NumBatchToMerge == 64);
|
||||
const auto unmerged_padded_desc = transform_tensor_descriptor(
|
||||
padded_desc,
|
||||
make_tuple(make_xor_transform(make_tuple(NumBatchToMerge, NumBatchToMerge)),
|
||||
make_pass_through_transform(K),
|
||||
make_pass_through_transform(Z * Y * X),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}),
|
||||
make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}));
|
||||
// Merge To M, N
|
||||
return transform_tensor_descriptor(
|
||||
unmerged_padded_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(NumBatchToMerge, K)),
|
||||
make_merge_transform(make_tuple(Z * Y * X, NumBatchToMerge, C))),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2, 3, 4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
|
||||
template <index_t NDim, typename enable_if<NDim == 2, bool>::type = false>
|
||||
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
|
||||
const index_t N,
|
||||
const index_t K,
|
||||
const index_t C,
|
||||
const std::array<index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& input_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& weights_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& output_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads,
|
||||
const index_t batch_k)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
const index_t Hi = input_spatial_lengths[0];
|
||||
const index_t Wi = input_spatial_lengths[1];
|
||||
|
||||
const index_t Ho = output_spatial_lengths[0];
|
||||
const index_t Wo = output_spatial_lengths[1];
|
||||
|
||||
const index_t Y = filter_spatial_lengths[0];
|
||||
const index_t X = filter_spatial_lengths[1];
|
||||
|
||||
const index_t ConvStrideH = conv_filter_strides[0];
|
||||
const index_t ConvStrideW = conv_filter_strides[1];
|
||||
|
||||
const index_t ConvDilationH = conv_filter_dilations[0];
|
||||
const index_t ConvDilationW = conv_filter_dilations[1];
|
||||
|
||||
const index_t InLeftPadH = input_left_pads[0];
|
||||
const index_t InLeftPadW = input_left_pads[1];
|
||||
|
||||
const index_t InRightPadH = input_right_pads[0];
|
||||
const index_t InRightPadW = input_right_pads[1];
|
||||
|
||||
const index_t GemmKTotal = N * Ho * Wo;
|
||||
const index_t GemmM = K * NumBatchToMerge;
|
||||
const index_t GemmN = C * X * Y * NumBatchToMerge;
|
||||
|
||||
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
|
||||
|
||||
const index_t GemmKBatch = batch_k;
|
||||
const index_t GemmK0 =
|
||||
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
|
||||
K0PerBlock;
|
||||
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
|
||||
|
||||
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Ho, Wo, K, output_strides);
|
||||
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Hi, Wi, C, input_strides);
|
||||
const auto wei_grid_desc = make_wei_grid_desc<NDim>(K, Y, X, C, weights_strides);
|
||||
|
||||
if constexpr(ConvBackwardWeightSpecialization ==
|
||||
device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
// A: output tensor
|
||||
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
|
||||
out_grid_desc,
|
||||
make_tuple(
|
||||
make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_merge_transform(make_tuple(NumBatchToMerge, GemmM / NumBatchToMerge))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// B: input tensor
|
||||
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
in_grid_desc,
|
||||
make_tuple(
|
||||
make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_merge_transform(make_tuple(NumBatchToMerge, GemmN / NumBatchToMerge))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
wei_grid_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
// A: output tensor
|
||||
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
|
||||
out_grid_desc,
|
||||
make_tuple(
|
||||
make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_merge_transform(make_tuple(NumBatchToMerge, GemmM / NumBatchToMerge))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// B: input tensor
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(NumBatchToMerge),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(NumBatchToMerge),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1, 2>{},
|
||||
Sequence<3, 4>{},
|
||||
Sequence<5>{},
|
||||
Sequence<6>{}));
|
||||
|
||||
const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Y, X, NumBatchToMerge, C)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<1, 3, 5, 6>{}, Sequence<0, 2, 4>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmktotal_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// Padd
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc =
|
||||
transform_tensor_descriptor(
|
||||
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GemmKBatch * GemmK0),
|
||||
make_right_pad_transform(GemmM, PadGemmM),
|
||||
make_pass_through_transform(GemmK1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc =
|
||||
transform_tensor_descriptor(
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GemmKBatch * GemmK0),
|
||||
make_right_pad_transform(GemmN, PadGemmN),
|
||||
make_pass_through_transform(GemmK1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
const auto wei_gemmm_gemmn_pad_grid_desc =
|
||||
transform_tensor_descriptor(wei_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmM, PadGemmM),
|
||||
make_right_pad_transform(GemmN, PadGemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc,
|
||||
wei_gemmm_gemmn_pad_grid_desc);
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t NDim, typename enable_if<NDim == 3, bool>::type = false>
|
||||
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
|
||||
const index_t N,
|
||||
const index_t K,
|
||||
const index_t C,
|
||||
const std::array<index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& input_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& weights_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& output_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads,
|
||||
const index_t batch_k)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
const index_t Di = input_spatial_lengths[0];
|
||||
const index_t Hi = input_spatial_lengths[1];
|
||||
const index_t Wi = input_spatial_lengths[2];
|
||||
|
||||
const index_t Do = output_spatial_lengths[0];
|
||||
const index_t Ho = output_spatial_lengths[1];
|
||||
const index_t Wo = output_spatial_lengths[2];
|
||||
|
||||
const index_t Z = filter_spatial_lengths[0];
|
||||
const index_t Y = filter_spatial_lengths[1];
|
||||
const index_t X = filter_spatial_lengths[2];
|
||||
|
||||
const index_t ConvStrideD = conv_filter_strides[0];
|
||||
const index_t ConvStrideH = conv_filter_strides[1];
|
||||
const index_t ConvStrideW = conv_filter_strides[2];
|
||||
|
||||
const index_t ConvDilationD = conv_filter_dilations[0];
|
||||
const index_t ConvDilationH = conv_filter_dilations[1];
|
||||
const index_t ConvDilationW = conv_filter_dilations[2];
|
||||
|
||||
const index_t InLeftPadD = input_left_pads[0];
|
||||
const index_t InLeftPadH = input_left_pads[1];
|
||||
const index_t InLeftPadW = input_left_pads[2];
|
||||
|
||||
const index_t InRightPadD = input_right_pads[0];
|
||||
const index_t InRightPadH = input_right_pads[1];
|
||||
const index_t InRightPadW = input_right_pads[2];
|
||||
|
||||
const index_t GemmKTotal = N * Do * Ho * Wo;
|
||||
const index_t GemmM = K * NumBatchToMerge;
|
||||
const index_t GemmN = C * Z * X * Y * NumBatchToMerge;
|
||||
|
||||
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
|
||||
|
||||
const index_t GemmKBatch = batch_k;
|
||||
const index_t GemmK0 =
|
||||
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
|
||||
K0PerBlock;
|
||||
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
|
||||
|
||||
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Do, Ho, Wo, K, output_strides);
|
||||
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Di, Hi, Wi, C, input_strides);
|
||||
const auto wei_grid_desc = make_wei_grid_desc<NDim>(K, Z, Y, X, C, weights_strides);
|
||||
|
||||
if constexpr(ConvBackwardWeightSpecialization ==
|
||||
device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
// A: output tensor
|
||||
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
|
||||
out_grid_desc,
|
||||
make_tuple(
|
||||
make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_merge_transform(make_tuple(NumBatchToMerge, GemmM / NumBatchToMerge))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// B: input tensor
|
||||
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
in_grid_desc,
|
||||
make_tuple(
|
||||
make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_merge_transform(make_tuple(NumBatchToMerge, GemmN / NumBatchToMerge))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
wei_grid_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
// A: output tensor
|
||||
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
|
||||
out_grid_desc,
|
||||
make_tuple(
|
||||
make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_merge_transform(make_tuple(NumBatchToMerge, GemmM / NumBatchToMerge))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// B: input tensor
|
||||
const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Di, InLeftPadD, InRightPadD),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(NumBatchToMerge),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{}),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{}));
|
||||
|
||||
const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_dip_hip_wip_c_grid_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(NumBatchToMerge),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{}),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1, 2>{},
|
||||
Sequence<3, 4>{},
|
||||
Sequence<5, 6>{},
|
||||
Sequence<7>{},
|
||||
Sequence<8>{}));
|
||||
|
||||
const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
in_n_z_do_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Z, Y, X, NumBatchToMerge, C)),
|
||||
make_merge_transform(make_tuple(N, Do, Ho, Wo))),
|
||||
make_tuple(Sequence<1, 3, 5, 7, 8>{}, Sequence<0, 2, 4, 6>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmktotal_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// Padd
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc =
|
||||
transform_tensor_descriptor(
|
||||
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GemmKBatch * GemmK0),
|
||||
make_right_pad_transform(GemmM, PadGemmM),
|
||||
make_pass_through_transform(GemmK1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc =
|
||||
transform_tensor_descriptor(
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GemmKBatch * GemmK0),
|
||||
make_right_pad_transform(GemmN, PadGemmN),
|
||||
make_pass_through_transform(GemmK1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
const auto wei_gemmm_gemmn_pad_grid_desc =
|
||||
transform_tensor_descriptor(wei_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmM, PadGemmM),
|
||||
make_right_pad_transform(GemmN, PadGemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc,
|
||||
wei_gemmm_gemmn_pad_grid_desc);
|
||||
}
|
||||
} // function end
|
||||
};
|
||||
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -124,7 +124,7 @@ struct EnvVar
|
||||
|
||||
#define CK_DECLARE_ENV_VAR_STR(name) CK_DECLARE_ENV_VAR(name, std::string, "")
|
||||
|
||||
#define ENV(name) \
|
||||
#define CK_ENV(name) \
|
||||
ck::env::name {}
|
||||
|
||||
template <class EnvVar>
|
||||
|
||||
@@ -29,6 +29,25 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz
|
||||
return __builtin_bit_cast(int32x4_t, res);
|
||||
}
|
||||
|
||||
namespace impl {
|
||||
// below type indicate the data type used for buffer load inline asm
|
||||
// clang-format off
|
||||
template<index_t N, typename T> struct buffer_load_trait;
|
||||
|
||||
template<typename T> struct buffer_load_trait<16, T> { using payload_t = fp32x4_t; };
|
||||
template<typename T> struct buffer_load_trait<8 , T> { using payload_t = fp32x2_t; };
|
||||
template<typename T> struct buffer_load_trait<4 , T> { using payload_t = float; };
|
||||
template<typename T> struct buffer_load_trait<2 , T> { using payload_t = float; };
|
||||
template<typename T> struct buffer_load_trait<1 , T> { using payload_t = float; };
|
||||
|
||||
#if CK_TILE_BUFFER_LOAD_RAW_BF16_WA
|
||||
template<> struct buffer_load_trait<16, thread_buffer<bf16_t, 8>> { using payload_t = bf16x8_t; };
|
||||
template<> struct buffer_load_trait<8 , thread_buffer<bf16_t, 4>> { using payload_t = bf16x4_t; };
|
||||
template<> struct buffer_load_trait<4 , thread_buffer<bf16_t, 2>> { using payload_t = bf16x2_t; };
|
||||
#endif
|
||||
// clang-format on
|
||||
} // namespace impl
|
||||
|
||||
// TODO: glc/slc/...
|
||||
template <index_t bytes>
|
||||
struct buffer_load;
|
||||
@@ -48,7 +67,7 @@ struct buffer_load<16>
|
||||
index_t /*flag*/ = 0)
|
||||
{
|
||||
static_assert(sizeof(T) == 16);
|
||||
using mbuf_t = fp32x4_t;
|
||||
using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t;
|
||||
asm volatile("buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4"
|
||||
: "+v"(reinterpret_cast<mbuf_t&>(value))
|
||||
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
|
||||
@@ -68,7 +87,7 @@ struct buffer_load<8>
|
||||
index_t /*flag*/ = 0)
|
||||
{
|
||||
static_assert(sizeof(T) == 8);
|
||||
using mbuf_t = fp32x2_t;
|
||||
using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t;
|
||||
asm volatile("buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4"
|
||||
: "+v"(reinterpret_cast<mbuf_t&>(value))
|
||||
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
|
||||
@@ -88,7 +107,7 @@ struct buffer_load<4>
|
||||
index_t /*flag*/ = 0)
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
using mbuf_t = float;
|
||||
using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t;
|
||||
asm volatile("buffer_load_dword %0, %1, %2, %3 offen offset:%4"
|
||||
: "+v"(reinterpret_cast<mbuf_t&>(value))
|
||||
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
|
||||
@@ -108,7 +127,7 @@ struct buffer_load<2>
|
||||
index_t /*flag*/ = 0)
|
||||
{
|
||||
static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually
|
||||
using mbuf_t = float;
|
||||
using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t;
|
||||
asm volatile("buffer_load_ushort %0, %1, %2, %3 offen offset:%4"
|
||||
: "+v"(reinterpret_cast<mbuf_t&>(value))
|
||||
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
|
||||
@@ -128,7 +147,7 @@ struct buffer_load<1>
|
||||
index_t /*flag*/ = 0)
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
using mbuf_t = float;
|
||||
using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t;
|
||||
asm volatile("buffer_load_ubyte %0, %1, %2, %3 offen offset:%4"
|
||||
: "+v"(reinterpret_cast<mbuf_t&>(value))
|
||||
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
|
||||
@@ -152,7 +171,7 @@ struct buffer_load_if<16>
|
||||
{
|
||||
static_assert(sizeof(T) == 16);
|
||||
auto saved_exec = __builtin_amdgcn_read_exec();
|
||||
using mbuf_t = fp32x4_t;
|
||||
using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t;
|
||||
static_assert(sizeof(mbuf_t) == sizeof(T));
|
||||
asm volatile(
|
||||
"v_cmpx_le_u32 exec, 1, %5\n"
|
||||
@@ -177,7 +196,7 @@ struct buffer_load_if<8>
|
||||
{
|
||||
static_assert(sizeof(T) == 8);
|
||||
auto saved_exec = __builtin_amdgcn_read_exec();
|
||||
using mbuf_t = fp32x2_t;
|
||||
using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t;
|
||||
asm volatile(
|
||||
"v_cmpx_le_u32 exec, 1, %5\n"
|
||||
"buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4\n"
|
||||
@@ -201,7 +220,7 @@ struct buffer_load_if<4>
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
auto saved_exec = __builtin_amdgcn_read_exec();
|
||||
using mbuf_t = float;
|
||||
using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t;
|
||||
asm volatile(
|
||||
"v_cmpx_le_u32 exec, 1, %5\n"
|
||||
"buffer_load_dword %0, %1, %2, %3 offen offset:%4\n"
|
||||
@@ -225,7 +244,7 @@ struct buffer_load_if<2>
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
auto saved_exec = __builtin_amdgcn_read_exec();
|
||||
using mbuf_t = float;
|
||||
using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t;
|
||||
asm volatile(
|
||||
"v_cmpx_le_u32 exec, 1, %5\n"
|
||||
"buffer_load_ushort %0, %1, %2, %3 offen offset:%4\n"
|
||||
@@ -249,7 +268,7 @@ struct buffer_load_if<1>
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
auto saved_exec = __builtin_amdgcn_read_exec();
|
||||
using mbuf_t = float;
|
||||
using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t;
|
||||
asm volatile(
|
||||
"v_cmpx_le_u32 exec, 1, %5\n"
|
||||
"buffer_load_ubyte %0, %1, %2, %3 offen offset:%4\n"
|
||||
|
||||
@@ -3,6 +3,21 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
|
||||
defined(__gfx942__)
|
||||
#define __gfx9__
|
||||
#endif
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
#define __gfx94__
|
||||
#endif
|
||||
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \
|
||||
defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__)
|
||||
#define __gfx103__
|
||||
#endif
|
||||
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
|
||||
#define __gfx11__
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS
|
||||
#include "hip/hip_runtime.h"
|
||||
#include "hip/hip_fp16.h"
|
||||
@@ -109,15 +124,13 @@
|
||||
// buffer atomic add: floating point
|
||||
#ifndef __HIP_DEVICE_COMPILE__ // for host code
|
||||
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
|
||||
#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
|
||||
defined(__gfx942__) // for GPU code
|
||||
#elif defined(__gfx9__) // for GPU code
|
||||
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
|
||||
#else // for GPU code
|
||||
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0
|
||||
#endif
|
||||
|
||||
#if(defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
|
||||
defined(__gfx942__)) // for GPU code
|
||||
#if(defined(__gfx90a__) || defined(__gfx94__)) // for GPU code
|
||||
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 1
|
||||
#else
|
||||
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 0
|
||||
@@ -137,13 +150,12 @@
|
||||
|
||||
#ifndef __HIP_DEVICE_COMPILE__ // for host code
|
||||
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0xffffffff
|
||||
#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \
|
||||
defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
|
||||
defined(__gfx942__) // for GPU code
|
||||
#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || \
|
||||
defined(__gfx9__) // for GPU code
|
||||
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x00020000
|
||||
#elif defined(__gfx1030__) // for GPU code
|
||||
#elif defined(__gfx103__) // for GPU code
|
||||
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31014000
|
||||
#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code
|
||||
#elif defined(__gfx11__) // for GPU code
|
||||
#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000
|
||||
#endif
|
||||
|
||||
@@ -159,3 +171,7 @@
|
||||
#ifndef CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
#define CK_TILE_FMHA_FWD_FAST_EXP2 0
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_BUFFER_LOAD_RAW_BF16_WA
|
||||
#define CK_TILE_BUFFER_LOAD_RAW_BF16_WA 1
|
||||
#endif
|
||||
|
||||
@@ -55,7 +55,7 @@ struct alignas(1) float8_e4m3_t
|
||||
{
|
||||
static constexpr int exponent = 4;
|
||||
static constexpr int mantissa = 3;
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
#if defined(__gfx94__)
|
||||
static constexpr int bias = 1 << (exponent - 1); // NANOO
|
||||
#else
|
||||
static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE
|
||||
@@ -113,7 +113,7 @@ struct alignas(1) float8_e5m2_t
|
||||
{
|
||||
static constexpr int exponent = 5;
|
||||
static constexpr int mantissa = 2;
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
#if defined(__gfx94__)
|
||||
static constexpr int bias = 1 << (exponent - 1); // NANOO
|
||||
#else
|
||||
static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE
|
||||
@@ -470,7 +470,7 @@ CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_sr_raw(float x)
|
||||
{
|
||||
constexpr int seed = 42;
|
||||
uint32_t rng = prand_generator_t<float, seed>{}(reinterpret_cast<uintptr_t>(&x), x);
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
#if defined(__gfx94__)
|
||||
float max_fp8 = 240.0f;
|
||||
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
|
||||
union
|
||||
@@ -500,7 +500,7 @@ CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_sr_raw(float x)
|
||||
{
|
||||
constexpr int seed = 42;
|
||||
uint32_t rng = prand_generator_t<float, seed>{}(reinterpret_cast<uintptr_t>(&x), x);
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
#if defined(__gfx94__)
|
||||
union
|
||||
{
|
||||
float fval;
|
||||
@@ -526,7 +526,7 @@ CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_sr_raw(float x)
|
||||
|
||||
CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_rtn_raw(float x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
#if defined(__gfx94__)
|
||||
float max_fp8 = 240.0f;
|
||||
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
|
||||
union
|
||||
@@ -554,7 +554,7 @@ CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_rtn_raw(float x)
|
||||
}
|
||||
CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_rtn_raw(float x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
#if defined(__gfx94__)
|
||||
union
|
||||
{
|
||||
float fval;
|
||||
@@ -598,7 +598,7 @@ CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_raw(float x, constant<rounding>)
|
||||
|
||||
CK_TILE_HOST_DEVICE float fp8_to_float_raw(fp8_raw_t x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
#if defined(__gfx94__)
|
||||
float fval;
|
||||
uint32_t i32val = static_cast<uint32_t>(x);
|
||||
fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0);
|
||||
@@ -612,7 +612,7 @@ CK_TILE_HOST_DEVICE float fp8_to_float_raw(fp8_raw_t x)
|
||||
|
||||
CK_TILE_HOST_DEVICE float bf8_to_float_raw(bf8_raw_t x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
#if defined(__gfx94__)
|
||||
float fval;
|
||||
uint32_t i32val = static_cast<uint32_t>(x);
|
||||
fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0);
|
||||
@@ -656,7 +656,7 @@ struct numeric_traits<fp8_t>
|
||||
{
|
||||
static constexpr int exp = 4;
|
||||
static constexpr int mant = 3;
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
#if defined(__gfx94__)
|
||||
static constexpr int bias = 8;
|
||||
#else
|
||||
static constexpr int bias = 7;
|
||||
@@ -668,7 +668,7 @@ struct numeric_traits<bf8_t>
|
||||
{
|
||||
static constexpr int exp = 5;
|
||||
static constexpr int mant = 2;
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
#if defined(__gfx94__)
|
||||
static constexpr int bias = 16;
|
||||
#else
|
||||
static constexpr int bias = 15; // IEEE
|
||||
|
||||
@@ -129,8 +129,8 @@ constexpr double fp16_to_double_hip(const fp16_hip_t& x)
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr fp16_hip_t float_to_fp16_hip(const float& x)
|
||||
{
|
||||
return __float2half(x);
|
||||
// return static_cast<fp16_hip_t>(x);
|
||||
// return __float2half(x);
|
||||
return static_cast<fp16_hip_t>(x);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
|
||||
@@ -56,7 +56,6 @@ CK_TILE_LEFT_UNARY_OP(+)
|
||||
CK_TILE_LEFT_UNARY_OP(-)
|
||||
CK_TILE_LEFT_UNARY_OP(~)
|
||||
CK_TILE_LEFT_UNARY_OP(!)
|
||||
CK_TILE_LEFT_UNARY_OP(*)
|
||||
|
||||
CK_TILE_BINARY_OP(+)
|
||||
CK_TILE_BINARY_OP(-)
|
||||
|
||||
@@ -112,7 +112,7 @@ namespace impl {
|
||||
template <typename OutDataType, typename InTensor>
|
||||
CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InTensor& in_dstr_tensors)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
#if defined(__gfx94__)
|
||||
// This API is designed to use the _pk_ series of functions
|
||||
constexpr auto in_tile_dstr = InTensor::get_tile_distribution();
|
||||
|
||||
|
||||
@@ -21,3 +21,4 @@
|
||||
#include "ck_tile/host/reference/reference_reduce.hpp"
|
||||
#include "ck_tile/host/reference/reference_softmax.hpp"
|
||||
#include "ck_tile/host/stream_config.hpp"
|
||||
#include "ck_tile/host/timer.hpp"
|
||||
|
||||
@@ -27,7 +27,14 @@ struct DeviceMem
|
||||
DeviceMem() : mpDeviceBuf(nullptr), mMemSize(0) {}
|
||||
DeviceMem(std::size_t mem_size) : mMemSize(mem_size)
|
||||
{
|
||||
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
|
||||
if(mMemSize != 0)
|
||||
{
|
||||
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
|
||||
}
|
||||
else
|
||||
{
|
||||
mpDeviceBuf = nullptr;
|
||||
}
|
||||
}
|
||||
void Realloc(std::size_t mem_size)
|
||||
{
|
||||
@@ -36,7 +43,14 @@ struct DeviceMem
|
||||
HIP_CHECK_ERROR(hipFree(mpDeviceBuf));
|
||||
}
|
||||
mMemSize = mem_size;
|
||||
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
|
||||
if(mMemSize != 0)
|
||||
{
|
||||
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
|
||||
}
|
||||
else
|
||||
{
|
||||
mpDeviceBuf = nullptr;
|
||||
}
|
||||
}
|
||||
void* GetDeviceBuffer() const { return mpDeviceBuf; }
|
||||
std::size_t GetBufferSize() const { return mMemSize; }
|
||||
@@ -47,15 +61,18 @@ struct DeviceMem
|
||||
HIP_CHECK_ERROR(
|
||||
hipMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, hipMemcpyHostToDevice));
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("ToDevice with an empty pointer");
|
||||
}
|
||||
// else
|
||||
// {
|
||||
// throw std::runtime_error("ToDevice with an empty pointer");
|
||||
// }
|
||||
}
|
||||
void ToDevice(const void* p, const std::size_t cpySize) const
|
||||
{
|
||||
HIP_CHECK_ERROR(
|
||||
hipMemcpy(mpDeviceBuf, const_cast<void*>(p), cpySize, hipMemcpyHostToDevice));
|
||||
if(mpDeviceBuf)
|
||||
{
|
||||
HIP_CHECK_ERROR(
|
||||
hipMemcpy(mpDeviceBuf, const_cast<void*>(p), cpySize, hipMemcpyHostToDevice));
|
||||
}
|
||||
}
|
||||
void FromDevice(void* p) const
|
||||
{
|
||||
@@ -63,14 +80,17 @@ struct DeviceMem
|
||||
{
|
||||
HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost));
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("FromDevice with an empty pointer");
|
||||
}
|
||||
// else
|
||||
// {
|
||||
// throw std::runtime_error("FromDevice with an empty pointer");
|
||||
// }
|
||||
}
|
||||
void FromDevice(void* p, const std::size_t cpySize) const
|
||||
{
|
||||
HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost));
|
||||
if(mpDeviceBuf)
|
||||
{
|
||||
HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost));
|
||||
}
|
||||
}
|
||||
void SetZero() const
|
||||
{
|
||||
@@ -82,13 +102,16 @@ struct DeviceMem
|
||||
template <typename T>
|
||||
void SetValue(T x) const
|
||||
{
|
||||
if(mMemSize % sizeof(T) != 0)
|
||||
if(mpDeviceBuf)
|
||||
{
|
||||
throw std::runtime_error("wrong! not entire DeviceMem will be set");
|
||||
}
|
||||
if(mMemSize % sizeof(T) != 0)
|
||||
{
|
||||
throw std::runtime_error("wrong! not entire DeviceMem will be set");
|
||||
}
|
||||
|
||||
// TODO: call a gpu kernel to set the value (?)
|
||||
set_buffer_value<T><<<1, 1024>>>(static_cast<T*>(mpDeviceBuf), x, mMemSize / sizeof(T));
|
||||
// TODO: call a gpu kernel to set the value (?)
|
||||
set_buffer_value<T><<<1, 1024>>>(static_cast<T*>(mpDeviceBuf), x, mMemSize / sizeof(T));
|
||||
}
|
||||
}
|
||||
~DeviceMem()
|
||||
{
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/host/stream_config.hpp"
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
#include "ck_tile/host/timer.hpp"
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <cstddef>
|
||||
|
||||
@@ -14,153 +15,92 @@ template <int MaxThreadPerBlock, int MinBlockPerCu, typename Kernel, typename...
|
||||
#if CK_TILE_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(MaxThreadPerBlock, MinBlockPerCu)
|
||||
#endif
|
||||
__global__ void kentry(Kernel f, Args... args)
|
||||
__global__ void kentry(Args... args)
|
||||
{
|
||||
f(args...);
|
||||
}
|
||||
|
||||
template <typename... Args, typename F>
|
||||
CK_TILE_HOST float launch_and_time_kernel(const stream_config& s,
|
||||
F kernel,
|
||||
dim3 grid_dim,
|
||||
dim3 block_dim,
|
||||
std::size_t lds_byte,
|
||||
Args... args)
|
||||
{
|
||||
#if CK_TILE_TIME_KERNEL
|
||||
if(s.time_kernel_)
|
||||
{
|
||||
// warm up
|
||||
for(int i = 0; i < s.cold_niters_; ++i)
|
||||
{
|
||||
kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
|
||||
hip_check_error(hipGetLastError());
|
||||
}
|
||||
|
||||
const int nrepeat = s.nrepeat_;
|
||||
hipEvent_t start, stop;
|
||||
|
||||
HIP_CHECK_ERROR(hipEventCreate(&start));
|
||||
HIP_CHECK_ERROR(hipEventCreate(&stop));
|
||||
|
||||
HIP_CHECK_ERROR(hipDeviceSynchronize());
|
||||
HIP_CHECK_ERROR(hipEventRecord(start, s.stream_id_));
|
||||
|
||||
for(int i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
|
||||
hip_check_error(hipGetLastError());
|
||||
}
|
||||
|
||||
HIP_CHECK_ERROR(hipEventRecord(stop, s.stream_id_));
|
||||
HIP_CHECK_ERROR(hipEventSynchronize(stop));
|
||||
|
||||
float total_time = 0;
|
||||
|
||||
HIP_CHECK_ERROR(hipEventElapsedTime(&total_time, start, stop));
|
||||
|
||||
return total_time / nrepeat;
|
||||
}
|
||||
else
|
||||
{
|
||||
kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
|
||||
hip_check_error(hipGetLastError());
|
||||
return 0;
|
||||
}
|
||||
#else
|
||||
kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
|
||||
hip_check_error(hipGetLastError());
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename... Args, typename F, typename PreProcessFunc>
|
||||
CK_TILE_HOST float launch_and_time_kernel_with_preprocess(const stream_config& s,
|
||||
PreProcessFunc preprocess,
|
||||
F kernel,
|
||||
dim3 grid_dim,
|
||||
dim3 block_dim,
|
||||
std::size_t lds_byte,
|
||||
Args... args)
|
||||
{
|
||||
#if CK_TILE_TIME_KERNEL
|
||||
if(s.time_kernel_)
|
||||
{
|
||||
#if CK_TILE_DEBUG_LOG
|
||||
printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
|
||||
__func__,
|
||||
grid_dim.x,
|
||||
grid_dim.y,
|
||||
grid_dim.z,
|
||||
block_dim.x,
|
||||
block_dim.y,
|
||||
block_dim.z);
|
||||
|
||||
printf("Warm up 1 time\n");
|
||||
#endif
|
||||
// warm up
|
||||
preprocess();
|
||||
kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
|
||||
hip_check_error(hipGetLastError());
|
||||
|
||||
const int nrepeat = 10;
|
||||
#if CK_TILE_DEBUG_LOG
|
||||
printf("Start running %d times...\n", nrepeat);
|
||||
#endif
|
||||
hipEvent_t start, stop;
|
||||
|
||||
HIP_CHECK_ERROR(hipEventCreate(&start));
|
||||
HIP_CHECK_ERROR(hipEventCreate(&stop));
|
||||
|
||||
HIP_CHECK_ERROR(hipDeviceSynchronize());
|
||||
HIP_CHECK_ERROR(hipEventRecord(start, s.stream_id_));
|
||||
|
||||
for(int i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
preprocess();
|
||||
kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
|
||||
hip_check_error(hipGetLastError());
|
||||
}
|
||||
|
||||
HIP_CHECK_ERROR(hipEventRecord(stop, s.stream_id_));
|
||||
HIP_CHECK_ERROR(hipEventSynchronize(stop));
|
||||
|
||||
float total_time = 0;
|
||||
|
||||
HIP_CHECK_ERROR(hipEventElapsedTime(&total_time, start, stop));
|
||||
|
||||
return total_time / nrepeat;
|
||||
}
|
||||
else
|
||||
{
|
||||
preprocess();
|
||||
kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
|
||||
hip_check_error(hipGetLastError());
|
||||
|
||||
return 0;
|
||||
}
|
||||
#else
|
||||
kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
|
||||
hip_check_error(hipGetLastError());
|
||||
|
||||
return 0;
|
||||
#endif
|
||||
Kernel{}(args...);
|
||||
}
|
||||
|
||||
//
|
||||
// return an anonymous functor (lambda) to be called later
|
||||
// the KernelImpl should be a class without non-static data member, or let's say
|
||||
// can be instantiated with "KernelImpl{}"
|
||||
//
|
||||
// the "static __device__ operator()(some_arg)" is the entry point of KernelImpl
|
||||
//
|
||||
template <int MaxThreadPerBlock = CK_TILE_MAX_THREAD_PER_BLOCK,
|
||||
int MinBlockPerCu = CK_TILE_MIN_BLOCK_PER_CU,
|
||||
typename KernelImpl,
|
||||
typename... Args>
|
||||
CK_TILE_HOST float launch_kernel(const stream_config& s,
|
||||
KernelImpl kernel_impl,
|
||||
dim3 grid_dim,
|
||||
dim3 block_dim,
|
||||
std::size_t dynamic_smem_byte,
|
||||
Args... args)
|
||||
CK_TILE_HOST auto
|
||||
make_kernel(KernelImpl /*f*/, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
|
||||
{
|
||||
const auto kernel = kentry<MaxThreadPerBlock, MinBlockPerCu, KernelImpl, Args...>;
|
||||
|
||||
return launch_and_time_kernel(
|
||||
s, kernel, grid_dim, block_dim, dynamic_smem_byte, kernel_impl, args...);
|
||||
return [=](const stream_config& s) {
|
||||
kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
|
||||
};
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
/*
|
||||
* launch_kernel()
|
||||
*
|
||||
* this function launches an arbitrary number of kernels with an optional timer (selected by stream_config)
|
||||
* the callables should have signature as "operator()(const stream_config& s){ ... }" to call
|
||||
*
|
||||
* the simplest way is pass in a lambda function, with "[=](const stream_config& s){ call_your_kernel_here() }"
|
||||
* as signature, for the callable (pay attention to the capture list)
|
||||
*
|
||||
* e.g.
|
||||
* ck_tile::launch_kernel(s,
|
||||
* [=](const stream_config& s){ hipMemset(ptr, 0, size); },
|
||||
* [=](const stream_config& s){ some_kernel<<<grids, blocks>>>(arg); }
|
||||
* );
|
||||
*
|
||||
* if you use a ck_tile kernel, or one similar to this style (a structure with "static __device__ operator()(...){}"),
|
||||
* you can pass your kernel to ck_tile::make_kernel(), which will create an anonymous functor for you,
|
||||
* then pass it to ck_tile::launch_kernel()
|
||||
*
|
||||
* e.g.
|
||||
* ck_tile::launch_kernel(s,
|
||||
* ck_tile::make_kernel<T0, B0>(kernel_0{}, grids0, blocks0, 0, kargs0),
|
||||
* ck_tile::make_kernel<T0, B1>(kernel_1{}, grids1, blocks1, 0, kargs1),
|
||||
* ...);
|
||||
**/
|
||||
// clang-format on
|
||||
template <typename... Callables>
|
||||
CK_TILE_HOST float launch_kernel(const stream_config& s, Callables... callables)
|
||||
{
|
||||
// clang-format off
|
||||
if(!s.time_kernel_) {
|
||||
(callables(s),...); hip_check_error(hipGetLastError());
|
||||
return 0;
|
||||
}
|
||||
if(s.is_gpu_timer_) {
|
||||
gpu_timer timer {};
|
||||
|
||||
// warmup
|
||||
for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } hip_check_error(hipGetLastError());
|
||||
|
||||
timer.start(s.stream_id_);
|
||||
for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } hip_check_error(hipGetLastError());
|
||||
timer.stop(s.stream_id_);
|
||||
|
||||
return timer.duration() / s.nrepeat_;
|
||||
}
|
||||
else {
|
||||
cpu_timer timer {};
|
||||
|
||||
// warmup
|
||||
for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } hip_check_error(hipGetLastError());
|
||||
|
||||
timer.start(s.stream_id_);
|
||||
for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } hip_check_error(hipGetLastError());
|
||||
timer.stop(s.stream_id_);
|
||||
|
||||
return timer.duration() / s.nrepeat_;
|
||||
}
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -6,6 +6,22 @@
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace ck_tile {
|
||||
/*
|
||||
* construct this structure with behavior as:
|
||||
*
|
||||
* // create stream config with default stream(NULL), and not timing the kernel
|
||||
* stream_config s = stream_config{};
|
||||
*
|
||||
* // create stream config with _some_stream_id_, and not timing the kernel
|
||||
* stream_config s = stream_config{_some_stream_id_};
|
||||
*
|
||||
* // create stream config with _some_stream_id_, and benchmark with warmup/repeat as default
|
||||
* stream_config s = stream_config{_some_stream_id_, true};
|
||||
*
|
||||
* // create stream config with _some_stream_id_, and benchmark using cpu timer
|
||||
* stream_config s = stream_config{_some_stream_id_, true, 0, 3, 10, false};
|
||||
**/
|
||||
|
||||
struct stream_config
|
||||
{
|
||||
hipStream_t stream_id_ = nullptr;
|
||||
@@ -13,5 +29,6 @@ struct stream_config
|
||||
int log_level_ = 0;
|
||||
int cold_niters_ = 3;
|
||||
int nrepeat_ = 10;
|
||||
bool is_gpu_timer_ = true; // keep compatible
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
79
include/ck_tile/host/timer.hpp
Normal file
79
include/ck_tile/host/timer.hpp
Normal file
@@ -0,0 +1,79 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <cstddef>
|
||||
#include <chrono>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct gpu_timer
|
||||
{
|
||||
CK_TILE_HOST gpu_timer()
|
||||
{
|
||||
HIP_CHECK_ERROR(hipEventCreate(&start_evt));
|
||||
HIP_CHECK_ERROR(hipEventCreate(&stop_evt));
|
||||
}
|
||||
|
||||
CK_TILE_HOST ~gpu_timer() noexcept(false)
|
||||
{
|
||||
HIP_CHECK_ERROR(hipEventDestroy(start_evt));
|
||||
HIP_CHECK_ERROR(hipEventDestroy(stop_evt));
|
||||
}
|
||||
|
||||
CK_TILE_HOST void start(const hipStream_t& s)
|
||||
{
|
||||
HIP_CHECK_ERROR(hipDeviceSynchronize());
|
||||
HIP_CHECK_ERROR(hipEventRecord(start_evt, s));
|
||||
}
|
||||
|
||||
CK_TILE_HOST void stop(const hipStream_t& s)
|
||||
{
|
||||
HIP_CHECK_ERROR(hipEventRecord(stop_evt, s));
|
||||
HIP_CHECK_ERROR(hipEventSynchronize(stop_evt));
|
||||
}
|
||||
// return in ms
|
||||
CK_TILE_HOST float duration() const
|
||||
{
|
||||
float ms = 0;
|
||||
HIP_CHECK_ERROR(hipEventElapsedTime(&ms, start_evt, stop_evt));
|
||||
return ms;
|
||||
}
|
||||
|
||||
private:
|
||||
hipEvent_t start_evt, stop_evt;
|
||||
};
|
||||
|
||||
struct cpu_timer
|
||||
{
|
||||
// torch.utils.benchmark.Timer(), there is a sync inside each timer callback
|
||||
CK_TILE_HOST void start(const hipStream_t&)
|
||||
{
|
||||
HIP_CHECK_ERROR(hipDeviceSynchronize());
|
||||
start_tick = std::chrono::high_resolution_clock::now();
|
||||
}
|
||||
// torch.utils.benchmark.Timer(), there is a sync inside each timer callback
|
||||
CK_TILE_HOST void stop(const hipStream_t&)
|
||||
{
|
||||
HIP_CHECK_ERROR(hipDeviceSynchronize());
|
||||
stop_tick = std::chrono::high_resolution_clock::now();
|
||||
}
|
||||
// return in ms
|
||||
CK_TILE_HOST float duration() const
|
||||
{
|
||||
double sec =
|
||||
std::chrono::duration_cast<std::chrono::duration<double>>(stop_tick - start_tick)
|
||||
.count();
|
||||
return static_cast<float>(sec * 1e3);
|
||||
}
|
||||
|
||||
private:
|
||||
std::chrono::time_point<std::chrono::high_resolution_clock> start_tick;
|
||||
std::chrono::time_point<std::chrono::high_resolution_clock> stop_tick;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -23,13 +23,13 @@ VERTICAL:
|
||||
[0] 1 2 3 4 5
|
||||
[0] 1 2 3 4 5
|
||||
|
||||
TOP_LEFT:
|
||||
TOP_LEFT(but negative):
|
||||
[0] 1 2 3 4 5
|
||||
1 [0] 1 2 3 4
|
||||
2 1 [0] 1 2 3
|
||||
3 2 1 [0] 1 2
|
||||
|
||||
FROM_BOTTOM_RIGHT:
|
||||
FROM_BOTTOM_RIGHT(but negative):
|
||||
2 1 [0] 1 2 3
|
||||
3 2 1 [0] 1 2
|
||||
4 3 2 1 [0] 1
|
||||
|
||||
@@ -79,7 +79,7 @@ struct FmhaFwdKernel
|
||||
return n.empty() ? n : std::string("p") + n; }();
|
||||
return
|
||||
_SS_("fmha_fwd_d") + _TS_(bfs::kK0BlockLength) + "_" + _SS_(t2s<QDataType>::name) +
|
||||
"_" + (kIsGroupMode ? "group" : "batch") + "_" +
|
||||
"_" + (kIsGroupMode ? "group" : "batch") + "_" + _SS_(TilePartitioner::name) + "_"
|
||||
"b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
|
||||
_TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" +
|
||||
"r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" +
|
||||
|
||||
@@ -18,6 +18,8 @@ struct FmhaFwdTilePartitioner
|
||||
static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1;
|
||||
static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1;
|
||||
|
||||
static constexpr const char* name = "shb";
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
|
||||
ck_tile::index_t nhead_,
|
||||
ck_tile::index_t seqlen_q_,
|
||||
@@ -51,4 +53,53 @@ struct FmhaFwdTilePartitioner
|
||||
}
|
||||
};
|
||||
|
||||
template <typename BlockFmhaShape_>
|
||||
using FmhaFwdTilePartitioner_SHB = FmhaFwdTilePartitioner<BlockFmhaShape_>;
|
||||
|
||||
template <typename BlockFmhaShape_>
|
||||
struct FmhaFwdTilePartitioner_HBS
|
||||
{
|
||||
using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;
|
||||
|
||||
static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0;
|
||||
static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;
|
||||
static constexpr ck_tile::index_t kK0 = BlockFmhaShape::kK0;
|
||||
static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1;
|
||||
static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1;
|
||||
|
||||
static constexpr const char* name = "hbs";
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
|
||||
ck_tile::index_t nhead_,
|
||||
ck_tile::index_t seqlen_q_,
|
||||
ck_tile::index_t hdim_v_)
|
||||
{
|
||||
// TODO: this may need tuning
|
||||
return dim3(nhead_,
|
||||
batch_size_,
|
||||
ck_tile::integer_divide_ceil(seqlen_q_, kM0) *
|
||||
ck_tile::integer_divide_ceil(hdim_v_, kN1));
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v)
|
||||
{
|
||||
// const index_t num_tile_m0 = seqlen_q / kM0;
|
||||
const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);
|
||||
|
||||
const index_t i_block = blockIdx.z;
|
||||
const index_t i_nhead = blockIdx.x;
|
||||
const index_t i_batch = blockIdx.y;
|
||||
|
||||
const auto f = [](index_t dividend, index_t divisor) {
|
||||
index_t quotient = dividend / divisor;
|
||||
index_t modulus = dividend - quotient * divisor;
|
||||
return ck_tile::make_tuple(quotient, modulus);
|
||||
};
|
||||
|
||||
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
|
||||
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -36,8 +36,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
|
||||
defined(__gfx942__)
|
||||
#if defined(__gfx9__)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, c_vec, 0, 0, 0);
|
||||
#else
|
||||
ck_tile::ignore = c_vec;
|
||||
@@ -49,8 +48,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
|
||||
defined(__gfx942__)
|
||||
#if defined(__gfx9__)
|
||||
return bit_cast<CVecType>(
|
||||
__builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0));
|
||||
#else
|
||||
@@ -89,8 +87,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
|
||||
defined(__gfx942__)
|
||||
#if defined(__gfx9__)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, c_vec, 0, 0, 0);
|
||||
#else
|
||||
ck_tile::ignore = c_vec;
|
||||
@@ -102,8 +99,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
|
||||
defined(__gfx942__)
|
||||
#if defined(__gfx9__)
|
||||
return bit_cast<CVecType>(
|
||||
__builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
|
||||
#else
|
||||
@@ -143,7 +139,7 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
#if defined(__gfx90a__) || defined(__gfx94__)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
|
||||
#elif defined(__gfx908__)
|
||||
static_for<0, 2, 1>{}([&](auto k) {
|
||||
@@ -167,7 +163,7 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
#if defined(__gfx90a__) || defined(__gfx94__)
|
||||
return bit_cast<CVecType>(
|
||||
__builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0));
|
||||
#elif defined(__gfx908__)
|
||||
@@ -220,7 +216,7 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
#if defined(__gfx90a__) || defined(__gfx94__)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
|
||||
#elif defined(__gfx908__)
|
||||
static_for<0, 2, 1>{}([&](auto k) {
|
||||
@@ -244,7 +240,7 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
#if defined(__gfx90a__) || defined(__gfx94__)
|
||||
return bit_cast<CVecType>(
|
||||
__builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
|
||||
#elif defined(__gfx908__)
|
||||
@@ -299,7 +295,7 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
#if defined(__gfx94__)
|
||||
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
@@ -333,7 +329,7 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
#if defined(__gfx94__)
|
||||
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
|
||||
|
||||
@@ -35,14 +35,24 @@ template <ck::index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardWeightSpecialization ConvSpec>
|
||||
ConvolutionBackwardWeightSpecialization ConvSpec,
|
||||
BlockGemmPipelineScheduler Scheduler,
|
||||
BlockGemmPipelineVersion PipelineVersion>
|
||||
using device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer|
|
||||
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector|
|
||||
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl|
|
||||
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| |
|
||||
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<1, 4, 8, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, S<1, 4, 8, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, 1, 1, S<1, 8, 1, 8>, 1>
|
||||
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumBatch|
|
||||
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
|
||||
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| |
|
||||
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
|
||||
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1>,
|
||||
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2>,
|
||||
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 32, 32, 1, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4>,
|
||||
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 32, 8, 32, 32, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 8>,
|
||||
|
||||
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1>,
|
||||
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2>,
|
||||
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 32, 32, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4>,
|
||||
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
|
||||
@@ -352,7 +352,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
@@ -421,7 +423,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -114,7 +114,19 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -205,7 +217,19 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
|
||||
@@ -36,6 +36,13 @@ function(add_instance_library INSTANCE_NAME)
|
||||
endif()
|
||||
endforeach()
|
||||
endif()
|
||||
|
||||
if(INSTANCES_ONLY)
|
||||
set(INST_TARGETS ${DEFAULT_GPU_TARGETS})
|
||||
else()
|
||||
set(INST_TARGETS ${GPU_TARGETS})
|
||||
endif()
|
||||
|
||||
# Do not build DL instances if DL_KERNELS macro is not set
|
||||
foreach(source IN LISTS ARGN)
|
||||
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
|
||||
@@ -45,21 +52,40 @@ function(add_instance_library INSTANCE_NAME)
|
||||
endforeach()
|
||||
# Do not build XDL instances if gfx9 targets are not on the target list
|
||||
foreach(source IN LISTS ARGN)
|
||||
if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl")
|
||||
if(NOT INST_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl")
|
||||
message("removing xdl instance ${source} ")
|
||||
list(REMOVE_ITEM ARGN "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
# Do not build WMMA instances if gfx11 targets are not on the target list
|
||||
foreach(source IN LISTS ARGN)
|
||||
if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma")
|
||||
if(NOT INST_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma")
|
||||
message("removing wmma instance ${source} ")
|
||||
list(REMOVE_ITEM ARGN "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#only continue if there are some source files left on the list
|
||||
if(ARGN)
|
||||
add_library(${INSTANCE_NAME} OBJECT ${ARGN})
|
||||
set(INST_OBJ)
|
||||
# Give every remaining source file its own set of --offload-arch flags and
# collect the sources in INST_OBJ.
foreach(source IN LISTS ARGN)
    # Re-seed the target list on every iteration: the REMOVE_ITEM calls below
    # mutate INST_TARGETS, so it must not carry over from the previous source.
    if(INSTANCES_ONLY)
        set(INST_TARGETS ${DEFAULT_GPU_TARGETS})
    else()
        set(INST_TARGETS ${GPU_TARGETS})
    endif()
    # Filter the targets by the kind of *this* source file. The wmma branch
    # previously tested the whole ARGN list, which stripped the gfx9 targets
    # from unrelated sources whenever any file in the library matched "_wmma".
    if(source MATCHES "_xdl")
        list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103)
    elseif(source MATCHES "_wmma")
        list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
    endif()
    # Build the flag string, one "--offload-arch=<target> " entry per target.
    set(offload_targets)
    foreach(target IN LISTS INST_TARGETS)
        string(APPEND offload_targets "--offload-arch=${target} ")
    endforeach()
    # Quoted so that an empty flag string is still passed as a (empty) property
    # value instead of leaving COMPILE_FLAGS without an argument.
    set_source_files_properties(${source} PROPERTIES COMPILE_FLAGS "${offload_targets}")
    list(APPEND INST_OBJ ${source})
endforeach()
|
||||
add_library(${INSTANCE_NAME} OBJECT ${INST_OBJ})
|
||||
target_compile_features(${INSTANCE_NAME} PUBLIC)
|
||||
set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
clang_tidy_check(${INSTANCE_NAME})
|
||||
@@ -131,6 +157,14 @@ FOREACH(subdir_path ${dir_list})
|
||||
if(NOT DEFINED DTYPES)
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
|
||||
if(INSTANCES_ONLY)
|
||||
set(INST_TARGETS ${DEFAULT_GPU_TARGETS})
|
||||
else()
|
||||
set(INST_TARGETS ${GPU_TARGETS})
|
||||
endif()
|
||||
|
||||
|
||||
if(("${cmake_instance}" MATCHES "quantization") AND (DEFINED DTYPES) AND (NOT DTYPES MATCHES "int8"))
|
||||
message("quantization instances will not be built!")
|
||||
set(add_inst 0)
|
||||
@@ -139,23 +173,23 @@ FOREACH(subdir_path ${dir_list})
|
||||
message("Found only dl instances, but DL_KERNELS is not set. Skipping.")
|
||||
set(add_inst 0)
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "ONLY XDL_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx9"))
|
||||
if(("${cmake_instance}" MATCHES "ONLY XDL_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx9"))
|
||||
message("Found only xdl instances, but gfx9 is not on the targets list. Skipping.")
|
||||
set(add_inst 0)
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11"))
|
||||
if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx11"))
|
||||
message("Found only wmma instances, but gfx11 is not on the targets list. Skipping.")
|
||||
set(add_inst 0)
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "ONLY XDL_AND_DL_KERNELS") AND (NOT DEFINED DL_KERNELS) AND (NOT GPU_TARGETS MATCHES "gfx9"))
|
||||
if(("${cmake_instance}" MATCHES "ONLY XDL_AND_DL_KERNELS") AND (NOT DEFINED DL_KERNELS) AND (NOT INST_TARGETS MATCHES "gfx9"))
|
||||
message("Found only xdl and dl instances, but gfx9 is not on the targets list and DL_KERNELS is not set. Skipping.")
|
||||
set(add_inst 0)
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx9"))
|
||||
if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx11") AND (NOT INST_TARGETS MATCHES "gfx9"))
|
||||
message("Found only xdl and wmma instances, but gfx11 and gfx9 are not on the targets list. Skipping.")
|
||||
set(add_inst 0)
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "XDL_DL_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx9") AND (NOT DEFINED DL_KERNELS))
|
||||
if(("${cmake_instance}" MATCHES "XDL_DL_WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx11") AND (NOT INST_TARGETS MATCHES "gfx9") AND (NOT DEFINED DL_KERNELS))
|
||||
message("Found xdl, dl, and wmma instances, but none of those meet the target list. Skipping.")
|
||||
set(add_inst 0)
|
||||
endif()
|
||||
|
||||
@@ -2,9 +2,14 @@
|
||||
set(GEMM_MULTI_ABD_INSTANCES)
|
||||
|
||||
list(APPEND GEMM_MULTI_ABD_INSTANCES
|
||||
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
|
||||
device_gemm_xdl_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
|
||||
device_gemm_xdl_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
|
||||
device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
|
||||
device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp
|
||||
|
||||
device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
|
||||
device_gemm_xdl_multi_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
|
||||
device_gemm_xdl_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
|
||||
device_gemm_xdl_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
|
||||
|
||||
#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Appends the "comp" and the "mem" instance tuples to `instances`.
// Both tuples are instantiated with the same argument list: the (B0, B1)
// layout/data-type tuples, empty ck::Tuple<> for the second and fourth
// arguments, Multiply and PassThrough as element-wise ops, GemmMNKPadding and
// Interwave.
// NOTE(review): the meaning of each template position is defined in the
// *_common.hpp header, which is not visible here - confirm there.
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
                                                      ck::Tuple<B0Layout, B1Layout>,
                                                      ck::Tuple<>,
                                                      ELayout,
                                                      AsDataType,
                                                      ck::Tuple<B0DataType, B1DataType>,
                                                      ck::Tuple<>,
                                                      EDataType,
                                                      AElementOp,
                                                      Multiply,
                                                      PassThrough>>>& instances)
{
    // Shared template arguments, spelled once.
    using LayoutTuple0 = ck::Tuple<B0Layout, B1Layout>;
    using LayoutTuple1 = ck::Tuple<>;
    using DataTuple0   = ck::Tuple<B0DataType, B1DataType>;
    using DataTuple1   = ck::Tuple<>;

    using CompInstanceList =
        device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances<LayoutTuple0,
                                                                       LayoutTuple1,
                                                                       DataTuple0,
                                                                       DataTuple1,
                                                                       Multiply,
                                                                       PassThrough,
                                                                       GemmMNKPadding,
                                                                       Interwave>;
    using MemInstanceList =
        device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances<LayoutTuple0,
                                                                      LayoutTuple1,
                                                                      DataTuple0,
                                                                      DataTuple1,
                                                                      Multiply,
                                                                      PassThrough,
                                                                      GemmMNKPadding,
                                                                      Interwave>;

    // Registration order is kept: comp instances first, then mem instances.
    add_device_operation_instances(instances, CompInstanceList{});
    add_device_operation_instances(instances, MemInstanceList{});
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,58 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
|
||||
|
||||
#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Appends the "comp" and the "mem" instance tuples to `instances`.
// Both tuples are instantiated with the same argument list: the (B0, B1)
// layout/data-type tuples, the single-element D0 layout/data-type tuples,
// Multiply and Add as element-wise ops, GemmMNKPadding and Interwave.
// NOTE(review): the meaning of each template position is defined in the
// *_common.hpp header, which is not visible here - confirm there.
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_v1_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
                                                      ck::Tuple<B0Layout, B1Layout>,
                                                      ck::Tuple<D0Layout>,
                                                      ELayout,
                                                      AsDataType,
                                                      ck::Tuple<B0DataType, B1DataType>,
                                                      ck::Tuple<D0DataType>,
                                                      EDataType,
                                                      AElementOp,
                                                      Multiply,
                                                      Add>>>& instances)
{
    // Shared template arguments, spelled once.
    using LayoutTuple0 = ck::Tuple<B0Layout, B1Layout>;
    using LayoutTuple1 = ck::Tuple<D0Layout>;
    using DataTuple0   = ck::Tuple<B0DataType, B1DataType>;
    using DataTuple1   = ck::Tuple<D0DataType>;

    using CompInstanceList =
        device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances<LayoutTuple0,
                                                                       LayoutTuple1,
                                                                       DataTuple0,
                                                                       DataTuple1,
                                                                       Multiply,
                                                                       Add,
                                                                       GemmMNKPadding,
                                                                       Interwave>;
    using MemInstanceList =
        device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances<LayoutTuple0,
                                                                      LayoutTuple1,
                                                                      DataTuple0,
                                                                      DataTuple1,
                                                                      Multiply,
                                                                      Add,
                                                                      GemmMNKPadding,
                                                                      Interwave>;

    // Registration order is kept: comp instances first, then mem instances.
    add_device_operation_instances(instances, CompInstanceList{});
    add_device_operation_instances(instances, MemInstanceList{});
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
@@ -52,112 +52,6 @@ void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_instances(
|
||||
Interwave>{});
|
||||
}
|
||||
|
||||
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_v1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
|
||||
ck::Tuple<B0Layout, B1Layout>,
|
||||
ck::Tuple<D0Layout>,
|
||||
ELayout,
|
||||
AsDataType,
|
||||
ck::Tuple<B0DataType, B1DataType>,
|
||||
ck::Tuple<D0DataType>,
|
||||
EDataType,
|
||||
AElementOp,
|
||||
Multiply,
|
||||
Add>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances<
|
||||
ck::Tuple<B0Layout, B1Layout>,
|
||||
ck::Tuple<D0Layout>,
|
||||
ck::Tuple<B0DataType, B1DataType>,
|
||||
ck::Tuple<D0DataType>,
|
||||
Multiply,
|
||||
Add,
|
||||
GemmMNKPadding,
|
||||
Interwave>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances<
|
||||
ck::Tuple<B0Layout, B1Layout>,
|
||||
ck::Tuple<D0Layout>,
|
||||
ck::Tuple<B0DataType, B1DataType>,
|
||||
ck::Tuple<D0DataType>,
|
||||
Multiply,
|
||||
Add,
|
||||
GemmMNKPadding,
|
||||
Interwave>{});
|
||||
}
|
||||
|
||||
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
|
||||
ck::Tuple<B0Layout, B1Layout>,
|
||||
ck::Tuple<>,
|
||||
ELayout,
|
||||
AsDataType,
|
||||
ck::Tuple<B0DataType, B1DataType>,
|
||||
ck::Tuple<>,
|
||||
EDataType,
|
||||
AElementOp,
|
||||
Multiply,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances<
|
||||
ck::Tuple<B0Layout, B1Layout>,
|
||||
ck::Tuple<>,
|
||||
ck::Tuple<B0DataType, B1DataType>,
|
||||
ck::Tuple<>,
|
||||
Multiply,
|
||||
PassThrough,
|
||||
GemmMNKPadding,
|
||||
Interwave>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances<
|
||||
ck::Tuple<B0Layout, B1Layout>,
|
||||
ck::Tuple<>,
|
||||
ck::Tuple<B0DataType, B1DataType>,
|
||||
ck::Tuple<>,
|
||||
Multiply,
|
||||
PassThrough,
|
||||
GemmMNKPadding,
|
||||
Interwave>{});
|
||||
}
|
||||
|
||||
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
|
||||
ck::Tuple<B0Layout, B1Layout>,
|
||||
ck::Tuple<>,
|
||||
ELayout,
|
||||
AsDataType,
|
||||
ck::Tuple<B0DataType, B1DataType>,
|
||||
ck::Tuple<>,
|
||||
EDataType,
|
||||
AElementOp,
|
||||
Multiply,
|
||||
FastGelu>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances<
|
||||
ck::Tuple<B0Layout, B1Layout>,
|
||||
ck::Tuple<>,
|
||||
ck::Tuple<B0DataType, B1DataType>,
|
||||
ck::Tuple<>,
|
||||
Multiply,
|
||||
FastGelu,
|
||||
GemmMNKPadding,
|
||||
Interwave>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances<
|
||||
ck::Tuple<B0Layout, B1Layout>,
|
||||
ck::Tuple<>,
|
||||
ck::Tuple<B0DataType, B1DataType>,
|
||||
ck::Tuple<>,
|
||||
Multiply,
|
||||
FastGelu,
|
||||
GemmMNKPadding,
|
||||
Interwave>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
|
||||
|
||||
#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Appends the "comp" and the "mem" instance tuples to `instances`.
// Both tuples are instantiated with the same argument list: the (B0, B1)
// layout/data-type tuples, empty ck::Tuple<> for the second and fourth
// arguments, Multiply and FastGelu as element-wise ops, GemmMNKPadding and
// Interwave.
// NOTE(review): the meaning of each template position is defined in the
// *_common.hpp header, which is not visible here - confirm there.
void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
                                                      ck::Tuple<B0Layout, B1Layout>,
                                                      ck::Tuple<>,
                                                      ELayout,
                                                      AsDataType,
                                                      ck::Tuple<B0DataType, B1DataType>,
                                                      ck::Tuple<>,
                                                      EDataType,
                                                      AElementOp,
                                                      Multiply,
                                                      FastGelu>>>& instances)
{
    // Shared template arguments, spelled once.
    using LayoutTuple0 = ck::Tuple<B0Layout, B1Layout>;
    using LayoutTuple1 = ck::Tuple<>;
    using DataTuple0   = ck::Tuple<B0DataType, B1DataType>;
    using DataTuple1   = ck::Tuple<>;

    using CompInstanceList =
        device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances<LayoutTuple0,
                                                                       LayoutTuple1,
                                                                       DataTuple0,
                                                                       DataTuple1,
                                                                       Multiply,
                                                                       FastGelu,
                                                                       GemmMNKPadding,
                                                                       Interwave>;
    using MemInstanceList =
        device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances<LayoutTuple0,
                                                                      LayoutTuple1,
                                                                      DataTuple0,
                                                                      DataTuple1,
                                                                      Multiply,
                                                                      FastGelu,
                                                                      GemmMNKPadding,
                                                                      Interwave>;

    // Registration order is kept: comp instances first, then mem instances.
    add_device_operation_instances(instances, CompInstanceList{});
    add_device_operation_instances(instances, MemInstanceList{});
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,58 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
|
||||
|
||||
#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Appends the "comp" and the "mem" instance tuples to `instances`.
// Both tuples are instantiated with the same argument list: single-element B0
// layout/data-type tuples, single-element B1 layout/data-type tuples,
// PassThrough and Multiply as element-wise ops, GemmMNKPadding and Interwave.
// NOTE(review): the meaning of each template position is defined in the
// *_common.hpp header, which is not visible here - confirm there.
void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
                                                      ck::Tuple<B0Layout>,
                                                      ck::Tuple<B1Layout>,
                                                      ELayout,
                                                      AsDataType,
                                                      ck::Tuple<B0DataType>,
                                                      ck::Tuple<B1DataType>,
                                                      EDataType,
                                                      AElementOp,
                                                      PassThrough,
                                                      Multiply>>>& instances)
{
    // Shared template arguments, spelled once.
    using LayoutTuple0 = ck::Tuple<B0Layout>;
    using LayoutTuple1 = ck::Tuple<B1Layout>;
    using DataTuple0   = ck::Tuple<B0DataType>;
    using DataTuple1   = ck::Tuple<B1DataType>;

    using CompInstanceList =
        device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances<LayoutTuple0,
                                                                       LayoutTuple1,
                                                                       DataTuple0,
                                                                       DataTuple1,
                                                                       PassThrough,
                                                                       Multiply,
                                                                       GemmMNKPadding,
                                                                       Interwave>;
    using MemInstanceList =
        device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances<LayoutTuple0,
                                                                      LayoutTuple1,
                                                                      DataTuple0,
                                                                      DataTuple1,
                                                                      PassThrough,
                                                                      Multiply,
                                                                      GemmMNKPadding,
                                                                      Interwave>;

    // Registration order is kept: comp instances first, then mem instances.
    add_device_operation_instances(instances, CompInstanceList{});
    add_device_operation_instances(instances, MemInstanceList{});
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,58 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
|
||||
|
||||
#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Appends the "comp" and the "mem" instance tuples to `instances`.
// Both tuples are instantiated with the same argument list: single-element B0
// layout/data-type tuples, (D0, B1) layout/data-type tuples, PassThrough and
// MultiplyAdd as element-wise ops, GemmMNKPadding and Interwave.
// NOTE(review): the meaning of each template position is defined in the
// *_common.hpp header, which is not visible here - confirm there.
void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_v1_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
                                                      ck::Tuple<B0Layout>,
                                                      ck::Tuple<D0Layout, B1Layout>,
                                                      ELayout,
                                                      AsDataType,
                                                      ck::Tuple<B0DataType>,
                                                      ck::Tuple<D0DataType, B1DataType>,
                                                      EDataType,
                                                      AElementOp,
                                                      PassThrough,
                                                      MultiplyAdd>>>& instances)
{
    // Shared template arguments, spelled once.
    using LayoutTuple0 = ck::Tuple<B0Layout>;
    using LayoutTuple1 = ck::Tuple<D0Layout, B1Layout>;
    using DataTuple0   = ck::Tuple<B0DataType>;
    using DataTuple1   = ck::Tuple<D0DataType, B1DataType>;

    using CompInstanceList =
        device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances<LayoutTuple0,
                                                                       LayoutTuple1,
                                                                       DataTuple0,
                                                                       DataTuple1,
                                                                       PassThrough,
                                                                       MultiplyAdd,
                                                                       GemmMNKPadding,
                                                                       Interwave>;
    using MemInstanceList =
        device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances<LayoutTuple0,
                                                                      LayoutTuple1,
                                                                      DataTuple0,
                                                                      DataTuple1,
                                                                      PassThrough,
                                                                      MultiplyAdd,
                                                                      GemmMNKPadding,
                                                                      Interwave>;

    // Registration order is kept: comp instances first, then mem instances.
    add_device_operation_instances(instances, CompInstanceList{});
    add_device_operation_instances(instances, MemInstanceList{});
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
@@ -52,111 +52,6 @@ void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_i
|
||||
Interwave>{});
|
||||
}
|
||||
|
||||
void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_v1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
|
||||
ck::Tuple<B0Layout>,
|
||||
ck::Tuple<D0Layout, B1Layout>,
|
||||
ELayout,
|
||||
AsDataType,
|
||||
ck::Tuple<B0DataType>,
|
||||
ck::Tuple<D0DataType, B1DataType>,
|
||||
EDataType,
|
||||
AElementOp,
|
||||
PassThrough,
|
||||
MultiplyAdd>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances<
|
||||
ck::Tuple<B0Layout>,
|
||||
ck::Tuple<D0Layout, B1Layout>,
|
||||
ck::Tuple<B0DataType>,
|
||||
ck::Tuple<D0DataType, B1DataType>,
|
||||
PassThrough,
|
||||
MultiplyAdd,
|
||||
GemmMNKPadding,
|
||||
Interwave>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances<
|
||||
ck::Tuple<B0Layout>,
|
||||
ck::Tuple<D0Layout, B1Layout>,
|
||||
ck::Tuple<B0DataType>,
|
||||
ck::Tuple<D0DataType, B1DataType>,
|
||||
PassThrough,
|
||||
MultiplyAdd,
|
||||
GemmMNKPadding,
|
||||
Interwave>{});
|
||||
}
|
||||
|
||||
void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
|
||||
ck::Tuple<B0Layout>,
|
||||
ck::Tuple<B1Layout>,
|
||||
ELayout,
|
||||
AsDataType,
|
||||
ck::Tuple<B0DataType>,
|
||||
ck::Tuple<B1DataType>,
|
||||
EDataType,
|
||||
AElementOp,
|
||||
PassThrough,
|
||||
Multiply>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances<ck::Tuple<B0Layout>,
|
||||
ck::Tuple<B1Layout>,
|
||||
ck::Tuple<B0DataType>,
|
||||
ck::Tuple<B1DataType>,
|
||||
PassThrough,
|
||||
Multiply,
|
||||
GemmMNKPadding,
|
||||
Interwave>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<B0Layout>,
|
||||
ck::Tuple<B1Layout>,
|
||||
ck::Tuple<B0DataType>,
|
||||
ck::Tuple<B1DataType>,
|
||||
PassThrough,
|
||||
Multiply,
|
||||
GemmMNKPadding,
|
||||
Interwave>{});
|
||||
}
|
||||
|
||||
void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
|
||||
ck::Tuple<B0Layout>,
|
||||
ck::Tuple<B1Layout>,
|
||||
ELayout,
|
||||
AsDataType,
|
||||
ck::Tuple<B0DataType>,
|
||||
ck::Tuple<B1DataType>,
|
||||
EDataType,
|
||||
AElementOp,
|
||||
PassThrough,
|
||||
MultiplyFastGelu>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances<ck::Tuple<B0Layout>,
|
||||
ck::Tuple<B1Layout>,
|
||||
ck::Tuple<B0DataType>,
|
||||
ck::Tuple<B1DataType>,
|
||||
PassThrough,
|
||||
MultiplyFastGelu,
|
||||
GemmMNKPadding,
|
||||
Interwave>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<B0Layout>,
|
||||
ck::Tuple<B1Layout>,
|
||||
ck::Tuple<B0DataType>,
|
||||
ck::Tuple<B1DataType>,
|
||||
PassThrough,
|
||||
MultiplyFastGelu,
|
||||
GemmMNKPadding,
|
||||
Interwave>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
|
||||
|
||||
#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleABD<AsLayout,
|
||||
ck::Tuple<B0Layout>,
|
||||
ck::Tuple<B1Layout>,
|
||||
ELayout,
|
||||
AsDataType,
|
||||
ck::Tuple<B0DataType>,
|
||||
ck::Tuple<B1DataType>,
|
||||
EDataType,
|
||||
AElementOp,
|
||||
PassThrough,
|
||||
MultiplyFastGelu>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances<ck::Tuple<B0Layout>,
|
||||
ck::Tuple<B1Layout>,
|
||||
ck::Tuple<B0DataType>,
|
||||
ck::Tuple<B1DataType>,
|
||||
PassThrough,
|
||||
MultiplyFastGelu,
|
||||
GemmMNKPadding,
|
||||
Interwave>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances<ck::Tuple<B0Layout>,
|
||||
ck::Tuple<B1Layout>,
|
||||
ck::Tuple<B0DataType>,
|
||||
ck::Tuple<B1DataType>,
|
||||
PassThrough,
|
||||
MultiplyFastGelu,
|
||||
GemmMNKPadding,
|
||||
Interwave>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -6,7 +6,9 @@ set(GROUPED_CONV2D_BWD_WEIGHT
|
||||
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp)
|
||||
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp
|
||||
)
|
||||
|
||||
if(DL_KERNELS)
|
||||
list(APPEND GROUPED_CONV2D_BWD_WEIGHT
|
||||
|
||||
@@ -10,7 +10,7 @@ namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -30,16 +30,9 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_in
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
ConvBwdWeightDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances<
|
||||
2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
ConvBwdWeightFilter1x1Stride1Pad0>{});
|
||||
ConvBwdWeightDefault,
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion::v2>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
@@ -0,0 +1,41 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances<
|
||||
2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
ConvBwdWeightDefault,
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion::v5>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,12 +1,14 @@
|
||||
# XDL_DL_WMMA_KERNELS
|
||||
# XDL_DL_WMMA_KERNELS
|
||||
set(GROUPED_CONV3D_BWD_WEIGHT
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp)
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp
|
||||
)
|
||||
|
||||
if(DL_KERNELS)
|
||||
list(APPEND GROUPED_CONV3D_BWD_WEIGHT
|
||||
|
||||
@@ -10,7 +10,7 @@ namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = out[n, do, ho, wo, g, k]
|
||||
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -30,16 +30,9 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
ConvBwdWeightDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
ConvBwdWeightFilter1x1Stride1Pad0>{});
|
||||
ConvBwdWeightDefault,
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion::v2>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
@@ -0,0 +1,41 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = out[n, do, ho, wo, g, k]
|
||||
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
ConvBwdWeightDefault,
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion::v5>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -88,7 +88,7 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification,
|
||||
|
||||
c_m_n_host_results.push_back(
|
||||
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n["
|
||||
<< i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i
|
||||
|
||||
@@ -87,7 +87,7 @@ bool profile_grouped_gemm_impl(int do_verification,
|
||||
|
||||
c_m_n_host_results.push_back(
|
||||
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n["
|
||||
<< i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i
|
||||
|
||||
@@ -82,7 +82,7 @@ bool profile_grouped_gemm_tile_loop_impl(int do_verification,
|
||||
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
|
||||
c_m_n_host_results.push_back(
|
||||
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n["
|
||||
<< i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i
|
||||
|
||||
@@ -88,7 +88,7 @@ bool profile_grouped_gemm_two_stage_impl(int do_verification,
|
||||
|
||||
c_m_n_host_results.push_back(
|
||||
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
|
||||
if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n["
|
||||
<< i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i
|
||||
|
||||
36
pyproject.toml
Normal file
36
pyproject.toml
Normal file
@@ -0,0 +1,36 @@
|
||||
[build-system]
|
||||
requires = ["setuptools", "setuptools-scm"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "rocm-composable-kernel"
|
||||
dynamic = ["version"]
|
||||
description = "Composable Kernel, performance-critical kernels for machine learning workloads"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.8"
|
||||
license = {file = "LICENSE"}
|
||||
classifiers = [
|
||||
"Programming Language :: Python :: 3",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Operating System :: OS Independent",
|
||||
]
|
||||
dependencies = []
|
||||
|
||||
[project.urls]
|
||||
"Homepage" = "https://github.com/rocm/composable_kernel"
|
||||
"Bug Tracker" = "https://github.com/rocm/composable_kernel/issues"
|
||||
|
||||
[tool.setuptools]
|
||||
packages = ["ck4inductor", "ck4inductor.include", "ck4inductor.library"]
|
||||
|
||||
[tool.setuptools.package-dir]
|
||||
ck4inductor = "python/ck4inductor"
|
||||
"ck4inductor.include" = "include"
|
||||
"ck4inductor.library" = "library"
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
"ck4inductor.include" = ["ck/**/*.hpp"]
|
||||
"ck4inductor.library" = ["src/tensor_operation_instance/gpu/gemm_universal/**/*.hpp"]
|
||||
|
||||
[tool.setuptools.dynamic]
|
||||
version = { attr = "setuptools_scm.get_version" }
|
||||
0
python/ck4inductor/__init__.py
Normal file
0
python/ck4inductor/__init__.py
Normal file
570
python/ck4inductor/universal_gemm/gen_instances.py
Normal file
570
python/ck4inductor/universal_gemm/gen_instances.py
Normal file
@@ -0,0 +1,570 @@
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
from dataclasses import fields, replace
|
||||
from functools import lru_cache, partial
|
||||
from typing import List
|
||||
|
||||
from ..util import library_path
|
||||
|
||||
from .op import CKGemmOperation
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _ck_library_dir():
    """
    Return the path of the `gemm_universal` instance sources under `library_path()`.

    Logs an error and returns `None` when that path does not exist.
    """
    sub_dirs = ("src", "tensor_operation_instance", "gpu", "gemm_universal")
    candidate = os.path.join(library_path(), *sub_dirs)
    if os.path.exists(candidate):
        return candidate
    log.error("CK library path %s does not exist", candidate)
    return None
|
||||
|
||||
|
||||
def parse_instances(str_instances: List[str]) -> List[CKGemmOperation]:
    """
    Parse the lines containing Universal Gemm template instances into `CKGemmOperation` instances

    Each line is expected to contain `DeviceGemm_Xdl_CShuffleV3<arg, arg, ...>`;
    everything up to and including the last occurrence of that name is discarded
    (so a `file:lineno:` prefix produced by `grep -n` is harmless).

    Template arguments are mapped positionally onto the dataclass fields:
      * `S<i, j, ...>` becomes a tuple of ints,
      * anything that parses as an integer becomes an `int`,
      * everything else is kept as a string (a C++ type alias or constant name).

    Lines that carry no template arguments at all (e.g. the single empty
    string produced by splitting empty grep output) are skipped instead of
    being turned into an instance whose fields are all `None`.
    """

    def maybe_int(s):
        try:
            return int(s)
        except ValueError:
            return s

    op_instances = []
    for line in str_instances:
        # drop surrounding whitespace first so that a trailing "\r" or tab
        # does not shield the closing bracket from the second strip()
        s_template_args = (
            line.split("DeviceGemm_Xdl_CShuffleV3")[-1].strip().strip("<>, ")
        )
        if not s_template_args:
            # nothing to parse on this line
            continue

        template_args = []
        i_current = 0
        while i_current < len(s_template_args):
            if s_template_args[i_current] == " ":
                # skip whitespace
                i_current += 1
                continue
            elif s_template_args[i_current : i_current + 2] == "S<":
                # parse template S<Index...>
                i_next = s_template_args.find(">", i_current)
                if i_next == -1:
                    # the sequence is the last argument and its closing ">"
                    # was removed by the outer strip(); read to end of string
                    i_next = len(s_template_args)
                template_args.append(
                    tuple(map(int, s_template_args[i_current + 2 : i_next].split(",")))
                )
                # step over the ">" and the "," that follows it
                i_current = i_next + 2
            else:
                # all string attributes must be either type aliases or global constants in C++
                i_next = s_template_args.find(",", i_current)
                template_args.append(
                    maybe_int(
                        s_template_args[i_current : i_next if i_next != -1 else None]
                    )
                )
                if i_next == -1:
                    # that was the last argument
                    break
                i_current = i_next + 1
        # pad with `None`s for the fields which are not defined in the instance
        new_instance = CKGemmOperation(
            *template_args,  # type: ignore[arg-type]
            *((None,) * (len(fields(CKGemmOperation)) - len(template_args))),
        )
        # the last 2 template parameters are optional
        # if they are absent, substitute them with default values from Universal Gemm C++ template declaration
        if new_instance.a_compute_dtype is None:
            new_instance.a_compute_dtype = new_instance.c_element_dtype
        if new_instance.b_compute_dtype is None:
            new_instance.b_compute_dtype = new_instance.c_element_dtype

        op_instances.append(new_instance)
    return op_instances
|
||||
|
||||
|
||||
def default_instances() -> List[CKGemmOperation]:
    """
    Return a one-element list holding the fallback Universal Gemm instance.

    fallback: known working op instance for problem size M=2240 K=256 N=2048
    """
    # all string attributes must be either type aliases or global constants in C++
    element_type = "F16"
    layout = "Row"
    identity_op = "PassThrough"

    settings = {
        # layouts / data types / element-wise ops
        "a_layout": layout,
        "b_layout": layout,
        "c_layout": layout,
        "a_element_dtype": element_type,
        "b_element_dtype": element_type,
        "c_element_dtype": element_type,
        "a_compute_dtype": element_type,
        "b_compute_dtype": element_type,
        "acc_dtype": "F32",
        "c_shuffle_dtype": element_type,
        "a_elementwise_op": identity_op,
        "b_elementwise_op": identity_op,
        "c_elementwise_op": identity_op,
        "gemm_specialization": "GemmSpecialization::Default",
        # tiling
        "block_size": 256,
        "m_per_block": 224,
        "n_per_block": 256,
        "k_per_block": 64,
        "a_k1": 8,
        "b_k1": 2,
        "m_per_xdl": 16,
        "n_per_xdl": 16,
        "m_xdl_per_wave": 7,
        "n_xdl_per_wave": 8,
        # A block transfer
        "a_block_transfer_thread_cluster_lengths_ak0_m_ak1": (8, 32, 1),
        "a_block_transfer_thread_cluster_arrange_order": (1, 0, 2),
        "a_block_transfer_src_access_order": (1, 0, 2),
        "a_block_transfer_src_vector_dim": 2,
        "a_block_transfer_src_scalar_per_vector": 8,
        "a_block_transfer_dst_scalar_per_vector_ak1": 8,
        "a_block_lds_extra_m": 0,
        # B block transfer
        "b_block_transfer_thread_cluster_lengths_bk0_n_bk1": (8, 32, 1),
        "b_block_transfer_thread_cluster_arrange_order": (0, 2, 1),
        "b_block_transfer_src_access_order": (0, 2, 1),
        "b_block_transfer_src_vector_dim": 1,
        "b_block_transfer_src_scalar_per_vector": 8,
        "b_block_transfer_dst_scalar_per_vector_bk1": 2,
        "b_block_lds_extra_n": 0,
        # C shuffle
        "c_shuffle_m_xdl_per_wave_per_shuffle": 1,
        "c_shuffle_n_xdl_per_wave_per_shuffle": 2,
        "c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block": (
            1,
            32,
            1,
            8,
        ),
        "c_shuffle_block_transfer_scalar_per_vector_n_per_block": 8,
        # pipeline
        "block_gemm_pipeline_scheduler": "BlockGemmPipelineScheduler::Intrawave",
        "block_gemm_pipeline_version": "BlockGemmPipelineVersion::v3",
    }
    return [CKGemmOperation(**settings)]  # type: ignore[arg-type]
|
||||
|
||||
|
||||
@lru_cache(None)
def gen_ops_library() -> List[CKGemmOperation]:
    """
    Parse the Universal Gemm instances defined in the composable kernel library folder.

    The library sources are searched with `grep` for `DeviceGemm_Xdl_CShuffleV3`
    and every matching line is handed to `parse_instances`. Instances whose
    scheduler is the placeholder `BlkGemmPipeSched` and/or whose specialization
    is the placeholder `GemmSpec` are expanded into one concrete instance per
    value of the respective domain.

    Returns an empty list when the library folder is missing or nothing matches.
    The result is memoized by `lru_cache`.
    """
    instance_marker = "DeviceGemm_Xdl_CShuffleV3"

    ck_library_dir = _ck_library_dir()
    if not ck_library_dir:
        return []

    grep_result = subprocess.run(
        [
            "grep",
            "-inR",
            instance_marker,
            # reuse the path that was just validated instead of recomputing it
            ck_library_dir,
        ],
        capture_output=True,
        text=True,
    )
    # grep exit status: 0 = match, 1 = no match, >1 = error.
    # On error keep whatever was matched, but make the failure visible.
    if grep_result.returncode > 1:
        log.warning(
            "grep over %s exited with status %d: %s",
            ck_library_dir,
            grep_result.returncode,
            grep_result.stderr.strip(),
        )

    # `grep -i` matches case-insensitively while the parser splits on the exact
    # name, so keep only lines the parser can actually handle. This also drops
    # the empty line obtained when grep produced no output at all.
    matched_lines = [
        line for line in grep_result.stdout.splitlines() if instance_marker in line
    ]
    op_instances = parse_instances(matched_lines)

    log.debug("ck instances from library: %d", len(op_instances))

    schedulers = [
        "BlockGemmPipelineScheduler::Intrawave",
        "BlockGemmPipelineScheduler::Interwave",
    ]
    gemm_specs = [
        "GemmSpecialization::Default",
        "GemmSpecialization::MPadding",
        "GemmSpecialization::NPadding",
        "GemmSpecialization::KPadding",
        "GemmSpecialization::MNPadding",
        "GemmSpecialization::MKPadding",
        "GemmSpecialization::NKPadding",
        "GemmSpecialization::MNKPadding",
    ]

    # substitute templated args by looping through their domains
    substitute_instances = []
    for instance in op_instances:
        sub_scheduler = instance.block_gemm_pipeline_scheduler == "BlkGemmPipeSched"
        sub_spec = instance.gemm_specialization == "GemmSpec"
        schedulers_range = (
            schedulers if sub_scheduler else [instance.block_gemm_pipeline_scheduler]
        )
        spec_range = gemm_specs if sub_spec else [instance.gemm_specialization]
        for scheduler in schedulers_range:
            for spec in spec_range:
                substitute_instances.append(
                    replace(
                        instance,
                        block_gemm_pipeline_scheduler=scheduler,
                        gemm_specialization=spec,
                    )
                )

    return substitute_instances
|
||||
|
||||
|
||||
@lru_cache(None)
def gen_ops_preselected() -> List[CKGemmOperation]:
    """
    Manually selected (through benchmarking) F16/F16/F16 Row/Col/Row instances

    The returned list is built from three families that share one set of
    common settings: 7 "compute friendly" instances (block size 256), then
    10 "memory friendly" instances (block size 128, Interwave, pipeline v2),
    then 2 "latency friendly" instances (block size 128, Intrawave, pipeline v1).
    """
    intrawave = "BlockGemmPipelineScheduler::Intrawave"
    interwave = "BlockGemmPipelineScheduler::Interwave"
    spec_default = "GemmSpecialization::Default"
    spec_mnk_padding = "GemmSpecialization::MNKPadding"
    pipeline = "BlockGemmPipelineVersion::"

    # settings every preselected instance has in common
    shared = {
        "a_layout": "Row",
        "b_layout": "Col",
        "c_layout": "Row",
        "a_element_dtype": "F16",
        "b_element_dtype": "F16",
        "c_element_dtype": "F16",
        "acc_dtype": "F32",
        "c_shuffle_dtype": "F16",
        "a_elementwise_op": "PassThrough",
        "b_elementwise_op": "PassThrough",
        "c_elementwise_op": "PassThrough",
        "k_per_block": 64,
        "a_k1": 8,
        "b_k1": 8,
        "a_block_transfer_thread_cluster_arrange_order": (1, 0, 2),
        "a_block_transfer_src_access_order": (1, 0, 2),
        "a_block_transfer_src_vector_dim": 2,
        "a_block_transfer_src_scalar_per_vector": 8,
        "a_block_transfer_dst_scalar_per_vector_ak1": 8,
        "a_block_lds_extra_m": 0,
        "b_block_transfer_thread_cluster_arrange_order": (1, 0, 2),
        "b_block_transfer_src_access_order": (1, 0, 2),
        "b_block_transfer_src_vector_dim": 2,
        "b_block_transfer_src_scalar_per_vector": 8,
        "b_block_transfer_dst_scalar_per_vector_bk1": 8,
        "b_block_lds_extra_n": 0,
        "a_compute_dtype": "F16",
        "b_compute_dtype": "F16",
    }

    def input_clusters(lengths):
        # A and B always use the same thread cluster lengths here
        return {
            "a_block_transfer_thread_cluster_lengths_ak0_m_ak1": lengths,
            "b_block_transfer_thread_cluster_lengths_bk0_n_bk1": lengths,
        }

    def tile(m_block, n_block, xdl, m_waves, n_waves, m_shuffle, n_shuffle):
        # m_per_xdl and n_per_xdl are equal for every preselected instance
        return {
            "m_per_block": m_block,
            "n_per_block": n_block,
            "m_per_xdl": xdl,
            "n_per_xdl": xdl,
            "m_xdl_per_wave": m_waves,
            "n_xdl_per_wave": n_waves,
            "c_shuffle_m_xdl_per_wave_per_shuffle": m_shuffle,
            "c_shuffle_n_xdl_per_wave_per_shuffle": n_shuffle,
        }

    def c_store(cluster_lengths, scalar_per_vector):
        return {
            "c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block": cluster_lengths,
            "c_shuffle_block_transfer_scalar_per_vector_n_per_block": scalar_per_vector,
        }

    compute_friendly = {
        **shared,
        "block_size": 256,
        **input_clusters((8, 32, 1)),
        **c_store((1, 32, 1, 8), 8),
        "block_gemm_pipeline_scheduler": intrawave,
    }
    memory_friendly = {
        **shared,
        "block_size": 128,
        **input_clusters((8, 16, 1)),
        "block_gemm_pipeline_scheduler": interwave,
        "block_gemm_pipeline_version": pipeline + "v2",
    }
    latency_friendly = {
        **shared,
        "gemm_specialization": spec_default,
        "block_size": 128,
        **input_clusters((8, 16, 1)),
        "block_gemm_pipeline_scheduler": intrawave,
        "block_gemm_pipeline_version": pipeline + "v1",
    }

    square_128 = tile(128, 128, 32, 2, 2, 1, 1)
    # (gemm specialization, tile settings, pipeline version)
    compute_rows = [
        (spec_mnk_padding, tile(224, 256, 16, 7, 8, 1, 2), "v3"),
        (spec_mnk_padding, square_128, "v3"),
        (spec_mnk_padding, square_128, "v4"),
        (spec_mnk_padding, square_128, "v5"),
        (spec_default, square_128, "v3"),
        (spec_default, square_128, "v4"),
        (spec_default, square_128, "v5"),
    ]
    # (gemm specialization, tile settings, C cluster lengths, C scalar per vector)
    memory_rows = [
        (spec_default, tile(16, 32, 16, 1, 1, 1, 1), (1, 16, 1, 8), 4),
        (spec_mnk_padding, tile(16, 32, 16, 1, 1, 1, 1), (1, 16, 1, 8), 4),
        (spec_mnk_padding, tile(16, 64, 16, 1, 2, 1, 2), (1, 16, 1, 8), 8),
        (spec_mnk_padding, tile(32, 64, 32, 1, 1, 1, 1), (1, 16, 1, 8), 8),
        (spec_mnk_padding, tile(32, 128, 32, 1, 2, 1, 1), (1, 16, 1, 8), 8),
        (spec_default, tile(32, 16, 16, 1, 1, 1, 1), (1, 32, 1, 4), 4),
        (spec_mnk_padding, tile(32, 16, 16, 1, 1, 1, 1), (1, 32, 1, 4), 4),
        (spec_mnk_padding, tile(64, 16, 16, 2, 1, 2, 1), (1, 64, 1, 2), 8),
        (spec_mnk_padding, tile(64, 32, 32, 1, 1, 1, 1), (1, 32, 1, 4), 8),
        (spec_mnk_padding, tile(128, 32, 32, 2, 1, 2, 1), (1, 32, 1, 4), 8),
    ]
    # (m_per_block, n_per_block, C cluster lengths)
    latency_rows = [
        (16, 32, (1, 16, 1, 8)),
        (32, 16, (1, 32, 1, 4)),
    ]

    selected = []
    for spec, tile_settings, version in compute_rows:
        selected.append(
            CKGemmOperation(
                **{
                    **compute_friendly,
                    **tile_settings,
                    "gemm_specialization": spec,
                    "block_gemm_pipeline_version": pipeline + version,
                }
            )
        )
    for spec, tile_settings, cluster_lengths, scalar_per_vector in memory_rows:
        selected.append(
            CKGemmOperation(
                **{
                    **memory_friendly,
                    **tile_settings,
                    **c_store(cluster_lengths, scalar_per_vector),
                    "gemm_specialization": spec,
                }
            )
        )
    for m_block, n_block, cluster_lengths in latency_rows:
        selected.append(
            CKGemmOperation(
                **{
                    **latency_friendly,
                    **tile(m_block, n_block, 16, 1, 1, 1, 1),
                    **c_store(cluster_lengths, 4),
                }
            )
        )
    return selected
|
||||
|
||||
|
||||
# When executed as a script, print the instances parsed from the CK library sources.
if __name__ == "__main__":
    print(gen_ops_library())
|
||||
95
python/ck4inductor/universal_gemm/op.py
Normal file
95
python/ck4inductor/universal_gemm/op.py
Normal file
@@ -0,0 +1,95 @@
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
||||
@dataclass
class CKGemmOperation:
    """
    Plain data record holding the template parameters of one CK Universal
    Gemm template instance.

    Every attribute below is one template parameter; the attribute order is
    the order used when the instance is flattened by ``dict_items`` and
    therefore the order in which parameters appear in ``key_name``.
    """

    # layouts
    a_layout: str
    b_layout: str
    c_layout: str

    # element dtypes
    a_element_dtype: str
    b_element_dtype: str
    c_element_dtype: str

    acc_dtype: str
    c_shuffle_dtype: str

    # elementwise ops
    a_elementwise_op: str
    b_elementwise_op: str
    c_elementwise_op: str

    gemm_specialization: str

    block_size: int

    # per-block sizes
    m_per_block: int
    n_per_block: int
    k_per_block: int

    a_k1: int
    b_k1: int

    m_per_xdl: int
    n_per_xdl: int

    m_xdl_per_wave: int
    n_xdl_per_wave: int

    # A block transfer
    a_block_transfer_thread_cluster_lengths_ak0_m_ak1: Tuple[int, int, int]
    a_block_transfer_thread_cluster_arrange_order: Tuple[int, int, int]
    a_block_transfer_src_access_order: Tuple[int, int, int]
    a_block_transfer_src_vector_dim: int
    a_block_transfer_src_scalar_per_vector: int
    a_block_transfer_dst_scalar_per_vector_ak1: int
    a_block_lds_extra_m: bool

    # B block transfer
    b_block_transfer_thread_cluster_lengths_bk0_n_bk1: Tuple[int, int, int]
    b_block_transfer_thread_cluster_arrange_order: Tuple[int, int, int]
    b_block_transfer_src_access_order: Tuple[int, int, int]

    b_block_transfer_src_vector_dim: int
    b_block_transfer_src_scalar_per_vector: int
    b_block_transfer_dst_scalar_per_vector_bk1: int
    b_block_lds_extra_n: bool

    # C shuffle
    c_shuffle_m_xdl_per_wave_per_shuffle: int
    c_shuffle_n_xdl_per_wave_per_shuffle: int

    c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block: Tuple[int, int, int, int]
    c_shuffle_block_transfer_scalar_per_vector_n_per_block: int

    block_gemm_pipeline_scheduler: str
    block_gemm_pipeline_version: Optional[str]

    a_compute_dtype: Optional[str]
    b_compute_dtype: Optional[str]

    def name(self):
        """Return the cpp alias used for this template instance."""
        return "ck_devicegemm_xdl_shuffle_v3_" + self.key_name()

    def key_name(self):
        """
        Flatten all fields into a single string meant to serve as a dict key.

        Each field contributes ``K<fieldname>V<value>`` where the field name
        is lower-cased with underscores removed, tuple values are joined with
        ``x`` and any other value is stringified with ``:`` stripped.  The
        per-field chunks are joined with ``_``.

        NOTE(review): the key is intended to be unique per instance; the
        original author marked the exact scheme as TBD.
        """

        def render_value(value):
            # Tuples become e.g. "4x64x1"; everything else goes through str()
            # with colons dropped (e.g. "A::B" -> "AB").
            if isinstance(value, tuple):
                return "x".join(str(component) for component in value)
            return str(value).replace(":", "")

        chunks = []
        for field_name, field_value in self.dict_items():
            tag = field_name.replace("_", "").lower()
            chunks.append("K" + tag + "V" + render_value(field_value))
        return "_".join(chunks)

    def dict_items(self):
        """Return the (field name, value) pairs of this instance, in field order."""
        as_mapping = asdict(self)
        return as_mapping.items()
|
||||
7
python/ck4inductor/util.py
Normal file
7
python/ck4inductor/util.py
Normal file
@@ -0,0 +1,7 @@
|
||||
import functools
|
||||
import os
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
def library_path():
    """
    Return the path of the ``library`` entry located next to this module.

    The path is built from the directory of ``__file__``; the result is
    memoized, so it is computed only on the first call.
    """
    module_dir = os.path.dirname(__file__)
    return os.path.join(module_dir, 'library')
|
||||
@@ -40,6 +40,13 @@ function(add_test_executable TEST_NAME)
|
||||
endif()
|
||||
endforeach()
|
||||
endif()
|
||||
|
||||
if(INSTANCES_ONLY)
|
||||
set(TEST_TARGETS ${DEFAULT_GPU_TARGETS})
|
||||
else()
|
||||
set(TEST_TARGETS ${GPU_TARGETS})
|
||||
endif()
|
||||
|
||||
foreach(source IN LISTS ARGN)
|
||||
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
|
||||
message("removing dl test ${source} ")
|
||||
@@ -47,20 +54,27 @@ function(add_test_executable TEST_NAME)
|
||||
endif()
|
||||
endforeach()
|
||||
foreach(source IN LISTS ARGN)
|
||||
if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "xdl")
|
||||
if(NOT TEST_TARGETS MATCHES "gfx9" AND source MATCHES "xdl")
|
||||
message("removing xdl test ${source} ")
|
||||
list(REMOVE_ITEM ARGN "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
foreach(source IN LISTS ARGN)
|
||||
if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "wmma")
|
||||
if(NOT TEST_TARGETS MATCHES "gfx11" AND source MATCHES "wmma")
|
||||
message("removing wmma test ${source} ")
|
||||
list(REMOVE_ITEM ARGN "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#only continue if there are some source files left on the list
|
||||
if(ARGN)
|
||||
if(ARGN MATCHES "_xdl")
|
||||
list(REMOVE_ITEM TEST_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103)
|
||||
elseif(ARGN MATCHES "_wmma")
|
||||
list(REMOVE_ITEM TEST_TARGETS gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
|
||||
endif()
|
||||
set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP)
|
||||
add_executable(${TEST_NAME} ${ARGN})
|
||||
set_property(TARGET ${TEST_NAME} PROPERTY HIP_ARCHITECTURES ${TEST_TARGETS} )
|
||||
target_link_libraries(${TEST_NAME} PRIVATE getopt::getopt)
|
||||
add_test(NAME ${TEST_NAME} COMMAND $<TARGET_FILE:${TEST_NAME}>)
|
||||
add_dependencies(tests ${TEST_NAME})
|
||||
@@ -105,6 +119,13 @@ function(add_gtest_executable TEST_NAME)
|
||||
endif()
|
||||
endforeach()
|
||||
endif()
|
||||
|
||||
if(INSTANCES_ONLY)
|
||||
set(TEST_TARGETS ${DEFAULT_GPU_TARGETS})
|
||||
else()
|
||||
set(TEST_TARGETS ${GPU_TARGETS})
|
||||
endif()
|
||||
|
||||
foreach(source IN LISTS ARGN)
|
||||
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
|
||||
message("removing dl test ${source} ")
|
||||
@@ -112,20 +133,27 @@ function(add_gtest_executable TEST_NAME)
|
||||
endif()
|
||||
endforeach()
|
||||
foreach(source IN LISTS ARGN)
|
||||
if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "xdl")
|
||||
if(NOT TEST_TARGETS MATCHES "gfx9" AND source MATCHES "xdl")
|
||||
message("removing xdl test ${source} ")
|
||||
list(REMOVE_ITEM ARGN "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
foreach(source IN LISTS ARGN)
|
||||
if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "wmma")
|
||||
if(NOT TEST_TARGETS MATCHES "gfx11" AND source MATCHES "wmma")
|
||||
message("removing wmma test ${source} ")
|
||||
list(REMOVE_ITEM ARGN "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#only continue if there are some source files left on the list
|
||||
if(ARGN)
|
||||
if(ARGN MATCHES "_xdl")
|
||||
list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103)
|
||||
elseif(ARGN MATCHES "_wmma")
|
||||
list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
|
||||
endif()
|
||||
set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP)
|
||||
add_executable(${TEST_NAME} ${ARGN})
|
||||
set_property(TARGET ${TEST_NAME} PROPERTY HIP_ARCHITECTURES ${TEST_TARGETS} )
|
||||
add_dependencies(tests ${TEST_NAME})
|
||||
add_dependencies(check ${TEST_NAME})
|
||||
|
||||
|
||||
@@ -32,19 +32,8 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
|
||||
std::vector<ck::utils::conv::ConvParam> conv_params;
|
||||
std::vector<ck::index_t> split_ks{1, 2};
|
||||
|
||||
bool skip_case(const ck::utils::conv::ConvParam& params, const ck::index_t split_k)
|
||||
bool skip_case(const ck::index_t split_k)
|
||||
{
|
||||
// Odd K or C values are supported only by DL and WMMA
|
||||
// kernels (only applies to fp16)
|
||||
// DL and WMMA kernels currently support only `split_k=1`
|
||||
if constexpr(std::is_same_v<InDataType, ck::half_t>)
|
||||
{
|
||||
if(split_k != 1 && (params.K_ % 2 != 0 || params.C_ % 2 != 0))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// 1d NWGC is only supported by DL kernel
|
||||
// DL kernel is only supported for split_k=1
|
||||
if constexpr(std::is_same_v<InLayout, NWGC> && std::is_same_v<OutLayout, NWGK>)
|
||||
@@ -100,7 +89,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
|
||||
{
|
||||
for(auto& param : conv_params)
|
||||
{
|
||||
if(!skip_case(param, split_k))
|
||||
if(!skip_case(split_k))
|
||||
{
|
||||
pass = pass && ck::profiler::profile_grouped_conv_bwd_weight_impl<NDimSpatial{},
|
||||
InLayout,
|
||||
@@ -189,6 +178,8 @@ TYPED_TEST(TestGroupedConvndBwdWeight2d, Test2D)
|
||||
this->conv_params.push_back({2, 1, 1, 1, 32, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
|
||||
this->conv_params.push_back({2, 1, 1, 64, 3, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
|
||||
this->conv_params.push_back({2, 1, 1, 1, 1, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
|
||||
this->conv_params.push_back(
|
||||
{2, 16, 16, 1, 1, {3, 3}, {28, 28}, {2, 2}, {1, 1}, {1, 1}, {1, 1}});
|
||||
this->Run();
|
||||
}
|
||||
|
||||
@@ -207,5 +198,7 @@ TYPED_TEST(TestGroupedConvndBwdWeight3d, Test3D)
|
||||
{3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
|
||||
this->conv_params.push_back(
|
||||
{3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
|
||||
this->conv_params.push_back(
|
||||
{3, 16, 16, 1, 1, {3, 3, 3}, {28, 28, 28}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
|
||||
this->Run();
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user