mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 01:36:06 +00:00
use constant tensor descriptor
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
#pragma once
|
||||
#include "device_tensor.cuh"
|
||||
#include "device_tensor_descriptor.cuh"
|
||||
|
||||
template <class TFloat,
|
||||
unsigned NWorkLen0,
|
||||
@@ -13,7 +13,7 @@ __device__ void blockwise_4d_tensor_op(const DeviceTensorDescriptor<4>& src_desc
|
||||
TFloat* __restrict__ p_dst,
|
||||
F f)
|
||||
{
|
||||
#if 1
|
||||
#if 0
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
printf("blockwise_4d_tensor_op: 0: \t"
|
||||
@@ -80,7 +80,7 @@ __device__ void blockwise_4d_tensor_op(const DeviceTensorDescriptor<4>& src_desc
|
||||
|
||||
f(p_src[dindex], p_dst[sindex]);
|
||||
|
||||
#if 1
|
||||
#if 0
|
||||
// if(threadIdx.x == 0)
|
||||
{
|
||||
printf("blockwise_4d_tensor_op: 1: thread id %u, \t"
|
||||
@@ -106,7 +106,7 @@ __device__ void threadwise_4d_tensor_op(const DeviceTensorDescriptor<4>& src_des
|
||||
TFloat* __restrict__ p_dst,
|
||||
F f)
|
||||
{
|
||||
#if 1
|
||||
#if 0
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
printf("threadwise_4d_tensor_op: 0: \t"
|
||||
@@ -151,7 +151,7 @@ __device__ void threadwise_4d_tensor_op(const DeviceTensorDescriptor<4>& src_des
|
||||
|
||||
f(p_src[sindex], p_dst[dindex]);
|
||||
|
||||
#if 1
|
||||
#if 0
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
printf("threadwise_4d_tensor_op: 1: thread id %u, \t"
|
||||
@@ -178,7 +178,7 @@ __device__ void threadwise_direct_convolution(const DeviceTensorDescriptor<4>& i
|
||||
const DeviceTensorDescriptor<4>& out_desc,
|
||||
TFloat* __restrict__ p_out)
|
||||
{
|
||||
#if 1
|
||||
#if 0
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
printf("threadwise_direct_convolution: 0: \t"
|
||||
@@ -212,7 +212,7 @@ __device__ void threadwise_direct_convolution(const DeviceTensorDescriptor<4>& i
|
||||
out_desc.GetStride(2),
|
||||
out_desc.GetStride(3));
|
||||
}
|
||||
#elif 1
|
||||
#elif 0
|
||||
{
|
||||
printf("threadwise_direct_convolution: 0: \t"
|
||||
"threadIdx.x %u \t"
|
||||
@@ -275,7 +275,7 @@ __device__ void threadwise_direct_convolution(const DeviceTensorDescriptor<4>& i
|
||||
|
||||
p_out[out_index] += p_wei[wei_index] * p_in[in_index];
|
||||
|
||||
#if 1
|
||||
#if 0
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
printf("threadwise_direct_convolution: 1: \t"
|
||||
@@ -320,7 +320,7 @@ __device__ void blockwise_convolution(const DeviceTensorDescriptor<4>& in_desc,
|
||||
const DeviceTensorDescriptor<4>& out_desc,
|
||||
TFloat* __restrict__ p_out)
|
||||
{
|
||||
#if 1
|
||||
#if 0
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
printf("blockwise_convolution: 0: \t"
|
||||
@@ -501,7 +501,7 @@ __global__ void gridwise_convolution(const DeviceTensorDescriptor<4> in_desc,
|
||||
const DeviceTensorDescriptor<4> out_desc,
|
||||
TFloat* __restrict__ p_out)
|
||||
{
|
||||
#if 1
|
||||
#if 0
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
printf("gridwise_convolution: 0: \t"
|
||||
|
||||
Reference in New Issue
Block a user