mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
NHWC conv 2d: fwd bfp16/int8, Device level tuning and host API (#73)
* add fwd bf16 conv
* change tuning parameter
* add int8 for conv fwd
* remove comments
* change tuning parameter for int8
* change init int8 example
* add test for conv2d fwd
* change device operation file position because of merge with develop
* fwd int8 use reference
* test_conv_fwd use reference
* add bracket for if statement
* rename fwd example name
* remove StaticBufferOfVectorTypeV2
* tweak example
Co-authored-by: ltqin <letaoqin@amd.com>
Co-authored-by: Chao Liu <chao.liu2@amd.com>
[ROCm/composable_kernel commit: 880fbee957]
This commit is contained in:
@@ -86,10 +86,9 @@ struct ReferenceConvFwd : public device::BaseOperator
 float v_wei;
 
 arg.in_element_op_(
-    v_in,
-    static_cast<const float>(arg.in_n_c_hi_wi_(n, c, hi, wi)));
+    v_in, ck::type_convert<float>(arg.in_n_c_hi_wi_(n, c, hi, wi)));
 arg.wei_element_op_(
-    v_wei, static_cast<const float>(arg.wei_k_c_y_x_(k, c, y, x)));
+    v_wei, ck::type_convert<float>(arg.wei_k_c_y_x_(k, c, y, x)));
 
 v_acc += v_in * v_wei;
 }
@@ -101,7 +100,7 @@ struct ReferenceConvFwd : public device::BaseOperator
 
 arg.out_element_op_(v_out, v_acc);
 
-arg.out_n_k_ho_wo_(n, k, ho, wo) = v_out;
+arg.out_n_k_ho_wo_(n, k, ho, wo) = ck::type_convert<OutDataType>(v_out);
 };
 
 make_ParallelTensorFunctor(f_nchw,
Reference in New Issue
Block a user