mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 11:30:02 +00:00
GEMM Multi D for CK Tile Engine (#2660)
* Readme for GEMM Multi D
* GEMM Multi D partial Progress
* GEMM Multi D partial Progress!
* CK Tile Engine GEMM Multi D : All Python files generated
* Partial Progress
* Partial Progress
* Partial Progress
* Partial Progress : Incorrect Result
* Partial Progress : Debugging
* Partial Progress : Correct Results
* Partial Progress - Incorrect Results
* Partial Progress - Commenting Passthrough bypass logic
* Changing Passthrough to MultiplyMultiply
* Correct Results!
* Fix and debug the pass through feature
* Sample commit
* Correct Results : MultiplyMultiply
* Code Cleanup
* Removing Failed Instances
* Working code before Unary element support
* Custom Elementwise Function support and working implementation for Mul and Add
* Updating README
* Working for Passthrough
* Review Comments : Minor Fixes
* Review Comments : Minor Fixes
* Readme Updated
* Partial Changes after Rebase
* Working Code : Changes after Rebase
* Updating Jenkins file
* Removing default value changed while testing
* Configuration changes in config files
* Tile Handler changes in GEMM Multi D Tile Engine
* Tile Handler changes in GEMM Multi D Example
* Change log for Gemm Multi D in CK Tile Engine
* Configuration changes in config files
---------
Co-authored-by: ThomasNing <thomasning@amd.com>
[ROCm/composable_kernel commit: 3f57ec3d2d]
This commit is contained in:
committed by
GitHub
parent
3f5d6a4d1f
commit
62fb072dbe
@@ -262,219 +262,67 @@ struct PassThroughPack2
|
||||
|
||||
struct PassThrough
|
||||
{
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
|
||||
template <class T>
|
||||
using raw_t = std::remove_cv_t<std::remove_reference_t<T>>;
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<double, double>(double& y, const double& x) const
|
||||
template <class Y, class X>
|
||||
CK_TILE_HOST_DEVICE void operator()(Y&& y, const X& x) const
|
||||
{
|
||||
y = x;
|
||||
/* Only do the assignment when
|
||||
- y is an *l-value* and
|
||||
- y is *not* const */
|
||||
if constexpr(std::is_lvalue_reference_v<Y&&> && !std::is_const_v<raw_t<Y>>)
|
||||
{
|
||||
y = ck_tile::type_convert<raw_t<Y>>(x);
|
||||
}
|
||||
/* otherwise (r-value or const) → do nothing */
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<float, double>(float& y, const double& x) const
|
||||
template <typename E, typename C, typename... Ds>
|
||||
CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const Ds&... ds) const -> void
|
||||
{
|
||||
y = type_convert<float>(x);
|
||||
}
|
||||
// Suppress unused parameter warning for ds
|
||||
((void)ds, ...);
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<double, float>(double& y, const float& x) const
|
||||
{
|
||||
y = type_convert<double>(x);
|
||||
// Just assign e with c
|
||||
if constexpr(std::is_same_v<E, C>)
|
||||
{
|
||||
e = c;
|
||||
}
|
||||
else
|
||||
{
|
||||
e = ck_tile::type_convert<E>(c);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<float, float>(float& y, const float& x) const
|
||||
struct MultiDMultiply
|
||||
{
|
||||
template <typename E, typename C, typename... Ds>
|
||||
CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const Ds&... ds) const -> void
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
// Start with the base value c
|
||||
float result = ck_tile::type_convert<float>(c);
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()<ck_tile::fp16_t, ck_tile::fp16_t>(ck_tile::fp16_t& y, const ck_tile::fp16_t& x) const
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
// Multiply by each D parameter using fold expression
|
||||
((result *= ck_tile::type_convert<float>(ds)), ...);
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<ck_tile::fp16_t, float>(ck_tile::fp16_t& y,
|
||||
const float& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::fp16_t>(x);
|
||||
e = ck_tile::type_convert<E>(result);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()<ck_tile::bf16_t, ck_tile::bf16_t>(ck_tile::bf16_t& y, const ck_tile::bf16_t& x) const
|
||||
struct MultiDAdd
|
||||
{
|
||||
template <typename E, typename C, typename... Ds>
|
||||
CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const Ds&... ds) const -> void
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
// Start with the base value c
|
||||
float result = ck_tile::type_convert<float>(c);
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<int32_t, int32_t>(int32_t& y, const int32_t& x) const
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
// Add by each D parameter using fold expression
|
||||
((result += ck_tile::type_convert<float>(ds)), ...);
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<ck_tile::bf16_t, float>(ck_tile::bf16_t& y,
|
||||
const float& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::bf16_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::bf16_t>(float& y,
|
||||
const ck_tile::bf16_t& x) const
|
||||
{
|
||||
y = type_convert<float>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()<ck_tile::bf16_t, ck_tile::fp16_t>(ck_tile::bf16_t& y, const ck_tile::fp16_t& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::bf16_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::fp16_t>(float& y,
|
||||
const ck_tile::fp16_t& x) const
|
||||
{
|
||||
y = type_convert<float>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<ck_tile::fp16_t, int8_t>(ck_tile::fp16_t& y,
|
||||
const int8_t& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::fp16_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<ck_tile::bf16_t, int8_t>(ck_tile::bf16_t& y,
|
||||
const int8_t& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::bf16_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<uint8_t, uint8_t>(uint8_t& y, const uint8_t& x) const
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<int8_t, int32_t>(int8_t& y, const int32_t& x) const
|
||||
{
|
||||
y = type_convert<int8_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<int32_t, int8_t>(int32_t& y, const int8_t& x) const
|
||||
{
|
||||
y = type_convert<int32_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<int8_t, float>(int8_t& y, const float& x) const
|
||||
{
|
||||
y = type_convert<int8_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<float, int8_t>(float& y, const int8_t& x) const
|
||||
{
|
||||
y = type_convert<float>(x);
|
||||
}
|
||||
|
||||
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<int4_t, int4_t>(int4_t& y, const int4_t& x) const
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<int4_t, int>(int4_t& y, const int& x) const
|
||||
{
|
||||
y = type_convert<int4_t>(x);
|
||||
}
|
||||
#endif
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()<ck_tile::fp8_t, ck_tile::fp8_t>(ck_tile::fp8_t& y, const ck_tile::fp8_t& x) const
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::fp8_t>(float& y,
|
||||
const ck_tile::fp8_t& x) const
|
||||
{
|
||||
y = type_convert<float>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<ck_tile::fp8_t, float>(ck_tile::fp8_t& y,
|
||||
const float& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::fp8_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()<ck_tile::fp16_t, ck_tile::fp8_t>(ck_tile::fp16_t& y, const ck_tile::fp8_t& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::fp16_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()<ck_tile::fp8_t, ck_tile::fp16_t>(ck_tile::fp8_t& y, const ck_tile::fp16_t& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::fp8_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()<ck_tile::bf8_t, ck_tile::bf8_t>(ck_tile::bf8_t& y, const ck_tile::bf8_t& x) const
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::bf8_t>(float& y,
|
||||
const ck_tile::bf8_t& x) const
|
||||
{
|
||||
y = type_convert<float>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<ck_tile::bf8_t, float>(ck_tile::bf8_t& y,
|
||||
const float& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::bf8_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()<ck_tile::fp16_t, ck_tile::bf8_t>(ck_tile::fp16_t& y, const ck_tile::bf8_t& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::fp16_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()<ck_tile::bf8_t, ck_tile::fp16_t>(ck_tile::bf8_t& y, const ck_tile::fp16_t& x) const
|
||||
{
|
||||
y = ck_tile::type_convert<ck_tile::bf8_t>(x);
|
||||
e = ck_tile::type_convert<E>(result);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -6,10 +6,10 @@
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce2d.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce2d_problem.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
#include "ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp"
|
||||
#include "ck_tile/ops/reduce/pipeline/reduce2d_default_policy.hpp"
|
||||
#include "ck_tile/ops/reduce/pipeline/reduce2d_problem.hpp"
|
||||
#include "ck_tile/ops/reduce/pipeline/reduce2d_shape.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
Reference in New Issue
Block a user