mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
NHWC conv 2d: fwd bfp16/int8, Device level tuning and host API (#73)
* add fwd bf16 conv * change tuning parameter * add int8 for conv fwd * remove comments * change tuning parameter for int8 * change init int8 example * add test for conv2d fwd * change device operation file pos because merge develop * fwd int8 use reference * test_conv_fwd use reference * add bracket for if statement * rename fwd example name * remove StaticBufferOfVectorTypeV2 * tweak example Co-authored-by: ltqin <letaoqin@amd.com> Co-authored-by: Chao Liu <chao.liu2@amd.com>
This commit is contained in:
@@ -86,10 +86,9 @@ struct ReferenceConvFwd : public device::BaseOperator
|
||||
float v_wei;
|
||||
|
||||
arg.in_element_op_(
|
||||
v_in,
|
||||
static_cast<const float>(arg.in_n_c_hi_wi_(n, c, hi, wi)));
|
||||
v_in, ck::type_convert<float>(arg.in_n_c_hi_wi_(n, c, hi, wi)));
|
||||
arg.wei_element_op_(
|
||||
v_wei, static_cast<const float>(arg.wei_k_c_y_x_(k, c, y, x)));
|
||||
v_wei, ck::type_convert<float>(arg.wei_k_c_y_x_(k, c, y, x)));
|
||||
|
||||
v_acc += v_in * v_wei;
|
||||
}
|
||||
@@ -101,7 +100,7 @@ struct ReferenceConvFwd : public device::BaseOperator
|
||||
|
||||
arg.out_element_op_(v_out, v_acc);
|
||||
|
||||
arg.out_n_k_ho_wo_(n, k, ho, wo) = v_out;
|
||||
arg.out_n_k_ho_wo_(n, k, ho, wo) = ck::type_convert<OutDataType>(v_out);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_nchw,
|
||||
|
||||
Reference in New Issue
Block a user