mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
Add examples of Gemm (data type: int4) (#367)
* Add GEMM examples for int4 Currently the source files are just copied from int8 examples * Re-use pre-defined alias in int4 exmples * Distinguish user-side type from kernel-side type * Add int4_t support for check_err() * Allow conversion between Tensor<> specializations * Re-format source files * Use different type for host tensors * Re-use CopyAsType<>() to implement copy ctor * Re-use element-wise operation type alias * Fix typo in alias names * Complete the int4 examples * Add constraint to Tensor<> templated methods * Add type traits 'is_signed_integral<>' * Add type constraints for integer version check_err<>() * Allow comparing different-sized integral types in check_err() * Check converted Tensor<int4_t> with golden Tensor<int8_t> * Remove constraint of Tensor<>::CopyAsType() * Avoid compilation error while disabling ck::int4_t support * Remove debug messages * Add #error directive to prevent compile sources with wrong setting * Simplify tensor usages in examples * Add constraint to check_err() input reference type * Align design with other PR * Use ""_uz to simplify example code * Avoid too much generalizing check_err() * Re-format GEMM instance template arguments * Extract int4 example common codes * Sort include directives * Move #include directives into new header * Move common codes together * Re-format template argument in example code * Reuse same implementation code for most of GEMM examples * Re-format common.hpp * Unify structured comment in examples * Use reinterpret_cast<>() for cross-type pointer conversion * Revert "Add type traits 'is_signed_integral<>'" This reverts commitf2c148efae. * Allow unsigned integer arguments for check_err() * Fix compilation error in check_err() * Remove unnecessary copy ctor for Tensor<> * Mark Tensor<> special member functions as 'default' * Use more strict condition to add code in examples * Fix wrong program return value of GEMM examples * Handle the case while user specify all the strides * Fix never-ran examples * Exit successfully if GEMM instance does not support given problem * Add missing 'else' keyword * Re-format CMakeLists.txt * Add wrapper function to hide value conversion while copying memory * Add new DeviceMem API to copy memory * Use new DeviceMem API to implement examples * Revert "Add new DeviceMem API to copy memory" This reverts commit3f190b0779. * Add conversion ctor for Tensor<> * Write Tensor<> conversion logics explicitly in example code * Convert Tensor<> values after transfer data to host
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/type.hpp"
|
||||
#include "ck/host_utility/io.hpp"
|
||||
|
||||
namespace ck {
|
||||
@@ -164,7 +165,7 @@ check_err(const std::vector<T>& out,
|
||||
{
|
||||
if(out.size() != ref.size())
|
||||
{
|
||||
std::cout << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
|
||||
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
@@ -185,8 +186,7 @@ check_err(const std::vector<T>& out,
|
||||
err_count++;
|
||||
if(err_count < 5)
|
||||
{
|
||||
std::cout << msg << " out[" << i << "] != ref[" << i
|
||||
<< "]: " << static_cast<int>(out[i]) << " != " << static_cast<int>(ref[i])
|
||||
std::cerr << msg << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r
|
||||
<< std::endl;
|
||||
}
|
||||
res = false;
|
||||
|
||||
Reference in New Issue
Block a user