mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Single-kernel GEMM + layernorm (#263)
* dump lds content in appropriate precision type * add squared add reduction op; allows sq sum * initial stub from regular gemm impl * layernorm example code & host verification * initial layernorm implementation * tidy up * make C0 precision type consistent with C * clang-tidy and additional comments * tighten up example code * account for extra flops/bytes from normalization * clang-format * c0 bias/beta/gamma now have its own precision type * AccElemOp for gemm outputs prior to feeding to layernorm * update workgroup mapping * rename kernel template param to reflect its dual use * use LDS mem pool for reduction workspace * change cshuffle precision type to f16; clean up * clang-format * correct naming * explicit cast * fully implemented gemm + bias + activation + add + norm * activation in correct order * reflect reduction API's recent change * amend * clean up; add comment * keep up with recent changes in reduction API * format * resolve merge conflicts Co-authored-by: Chao Liu <chao.liu2@amd.com>
This commit is contained in:
@@ -12,21 +12,27 @@ template <typename T, typename Enable = void>
|
||||
struct PrintAsType;
|
||||
|
||||
template <typename T>
|
||||
struct PrintAsType<T, typename std::enable_if<std::is_floating_point<T>::value>::value>
|
||||
struct PrintAsType<T, typename std::enable_if<std::is_floating_point<T>::value>::type>
|
||||
{
|
||||
using type = float;
|
||||
__host__ __device__ static void Print(const T& p) { printf("%.3f ", static_cast<type>(p)); }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PrintAsType<ck::half_t, void>
|
||||
{
|
||||
using type = float;
|
||||
__host__ __device__ static void Print(const ck::half_t& p)
|
||||
{
|
||||
printf("%.3f ", static_cast<type>(p));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct PrintAsType<T, typename std::enable_if<std::is_integral<T>::value>::value>
|
||||
struct PrintAsType<T, typename std::enable_if<std::is_integral<T>::value>::type>
|
||||
{
|
||||
using type = int;
|
||||
__host__ __device__ static void Print(const T& p) { printf("%d ", static_cast<type>(p)); }
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
@@ -41,7 +47,6 @@ struct PrintAsType<T, typename std::enable_if<std::is_integral<T>::value>::value
|
||||
template <typename T, index_t element_stride = 1, index_t row_bytes = 128>
|
||||
__device__ void print_shared(T const* p_shared, index_t num_elements)
|
||||
{
|
||||
using PrintType = typename detail::PrintAsType<T>::type;
|
||||
constexpr index_t row_elements = row_bytes / sizeof(T);
|
||||
static_assert((element_stride >= 1 && element_stride <= row_elements),
|
||||
"element_stride should between [1, row_elements]");
|
||||
@@ -63,7 +68,7 @@ __device__ void print_shared(T const* p_shared, index_t num_elements)
|
||||
printf("elem %5d: ", i);
|
||||
for(index_t j = 0; j < row_elements; j += element_stride)
|
||||
{
|
||||
printf("%.0f ", static_cast<PrintType>(p_shared[i + j]));
|
||||
detail::PrintAsType<T>::Print(p_shared[i + j]);
|
||||
}
|
||||
|
||||
printf("\n");
|
||||
|
||||
Reference in New Issue
Block a user