mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 03:49:41 +00:00
@@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
#include "ConstantTensorDescriptor.hip.hpp"
|
||||
#include "threadwise_nd_tensor_op.hip.hpp"
|
||||
|
||||
// optimized for the scenario where p_in, p_wei, p_out are in registers
|
||||
template <class TInWei, class TOut, class InDesc, class WeiDesc, class OutDesc>
|
||||
@@ -84,10 +85,12 @@ __device__ void threadwise_direct_convolution_2(InDesc,
|
||||
TInWei p_wei_reg[wei_reg_desc.GetElementSpace()];
|
||||
|
||||
// copy input tensor into register
|
||||
threadwise_4d_tensor_copy(in_desc, p_in, in_reg_desc, p_in_reg, in_reg_desc.GetLengths());
|
||||
threadwise_nd_tensor_copy(
|
||||
in_desc, p_in, in_reg_desc, p_in_reg, in_reg_desc.GetLengths(), Number<1>{});
|
||||
|
||||
// copy weight tensor into register
|
||||
threadwise_4d_tensor_copy(wei_desc, p_wei, wei_reg_desc, p_wei_reg, wei_reg_desc.GetLengths());
|
||||
threadwise_nd_tensor_copy(
|
||||
wei_desc, p_wei, wei_reg_desc, p_wei_reg, wei_reg_desc.GetLengths(), Number<1>{});
|
||||
|
||||
// do convolution
|
||||
threadwise_direct_convolution_1(
|
||||
|
||||
Reference in New Issue
Block a user