mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
GEMM Multi D for CK Tile Engine (#2660)
* Readme for GEMM Multi D * GEMM Multi D partial Progress * GEMM Multi D partial Progress! * CK Tile Engine GEMM Multi D : All Python files generated * Partial Progress * Partial Progress * Partial Progress * Partial Progress : Incorrect Result * Partial Progress : Debugging * Partial Progress : Correct Results * Partial Progress - Incorrect Results * Partial Progress - Commenting Passthrough bypass logic * Changing Passthrough to MultiplyMultiply * Correct Results! * Fix and debug the pass through feature * Sample commit * Correct Results : MultiplyMultiply * Code Cleanup * Removing Failed Instances * Working code before Unary element support * Custom Elementwise Function support and working implementation for Mul and Add * Updating README * Working for Passthrough * Review Comments : Minor Fixes * Review Comments : Minor Fixes * Readme Updated * Partial Changes after Rebase * Working Code : Changes after Rebase * Updating Jenkins file * Removing default value changed while testing * Configuration changes in config files * Tile Handler changes in GEMM Multi D Tile Engine * Tile Handler changes in GEMM Multi D Example * Change log for Gemm Multi D in CK Tile Engine * Configuration changes in config files --------- Co-authored-by: ThomasNing <thomasning@amd.com>
This commit is contained in:
committed by
GitHub
parent
30dafe8281
commit
3f57ec3d2d
@@ -262,219 +262,67 @@ struct PassThroughPack2
|
||||
|
||||
struct PassThrough
|
||||
{
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
|
||||
template <class T>
|
||||
using raw_t = std::remove_cv_t<std::remove_reference_t<T>>;
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<double, double>(double& y, const double& x) const
|
||||
template <class Y, class X>
|
||||
CK_TILE_HOST_DEVICE void operator()(Y&& y, const X& x) const
|
||||
{
|
||||
y = x;
|
||||
/* Only do the assignment when
|
||||
- y is an *l-value* and
|
||||
- y is *not* const */
|
||||
if constexpr(std::is_lvalue_reference_v<Y&&> && !std::is_const_v<raw_t<Y>>)
|
||||
{
|
||||
y = ck_tile::type_convert<raw_t<Y>>(x);
|
||||
}
|
||||
/* otherwise (r-value or const) → do nothing */
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<float, double>(float& y, const double& x) const
|
||||
template <typename E, typename C, typename... Ds>
|
||||
CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const Ds&... ds) const -> void
|
||||
{
|
||||
y = type_convert<float>(x);
|
||||
}
|
||||
// Suppress unused parameter warning for ds
|
||||
((void)ds, ...);
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<double, float>(double& y, const float& x) const
|
||||
{
|
||||
y = type_convert<double>(x);
|
||||
// Just assign c to e
|
||||
if constexpr(std::is_same_v<E, C>)
|
||||
{
|
||||
e = c;
|
||||
}
|
||||
else
|
||||
{
|
||||
e = ck_tile::type_convert<E>(c);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<float, float>(float& y, const float& x) const
|
||||
struct MultiDMultiply
|
||||
{
|
||||
template <typename E, typename C, typename... Ds>
|
||||
CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const Ds&... ds) const -> void
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
// Start with the base value c
|
||||
float result = ck_tile::type_convert<float>(c);
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()<ck_tile::fp16_t, ck_tile::fp16_t>(ck_tile::fp16_t& y, const ck_tile::fp16_t& x) const
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
// Multiply by each D parameter using a fold expression
|
||||
((result *= ck_tile::type_convert<float>(ds)), ...);
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<ck_tile::fp16_t, float>(ck_tile::fp16_t& y,
|
||||
const float& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::fp16_t>(x);
|
||||
e = ck_tile::type_convert<E>(result);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()<ck_tile::bf16_t, ck_tile::bf16_t>(ck_tile::bf16_t& y, const ck_tile::bf16_t& x) const
|
||||
struct MultiDAdd
|
||||
{
|
||||
template <typename E, typename C, typename... Ds>
|
||||
CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const Ds&... ds) const -> void
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
// Start with the base value c
|
||||
float result = ck_tile::type_convert<float>(c);
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<int32_t, int32_t>(int32_t& y, const int32_t& x) const
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
// Add each D parameter using a fold expression
|
||||
((result += ck_tile::type_convert<float>(ds)), ...);
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<ck_tile::bf16_t, float>(ck_tile::bf16_t& y,
|
||||
const float& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::bf16_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::bf16_t>(float& y,
|
||||
const ck_tile::bf16_t& x) const
|
||||
{
|
||||
y = type_convert<float>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()<ck_tile::bf16_t, ck_tile::fp16_t>(ck_tile::bf16_t& y, const ck_tile::fp16_t& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::bf16_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::fp16_t>(float& y,
|
||||
const ck_tile::fp16_t& x) const
|
||||
{
|
||||
y = type_convert<float>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<ck_tile::fp16_t, int8_t>(ck_tile::fp16_t& y,
|
||||
const int8_t& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::fp16_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<ck_tile::bf16_t, int8_t>(ck_tile::bf16_t& y,
|
||||
const int8_t& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::bf16_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<uint8_t, uint8_t>(uint8_t& y, const uint8_t& x) const
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<int8_t, int32_t>(int8_t& y, const int32_t& x) const
|
||||
{
|
||||
y = type_convert<int8_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<int32_t, int8_t>(int32_t& y, const int8_t& x) const
|
||||
{
|
||||
y = type_convert<int32_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<int8_t, float>(int8_t& y, const float& x) const
|
||||
{
|
||||
y = type_convert<int8_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<float, int8_t>(float& y, const int8_t& x) const
|
||||
{
|
||||
y = type_convert<float>(x);
|
||||
}
|
||||
|
||||
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<int4_t, int4_t>(int4_t& y, const int4_t& x) const
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<int4_t, int>(int4_t& y, const int& x) const
|
||||
{
|
||||
y = type_convert<int4_t>(x);
|
||||
}
|
||||
#endif
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()<ck_tile::fp8_t, ck_tile::fp8_t>(ck_tile::fp8_t& y, const ck_tile::fp8_t& x) const
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::fp8_t>(float& y,
|
||||
const ck_tile::fp8_t& x) const
|
||||
{
|
||||
y = type_convert<float>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<ck_tile::fp8_t, float>(ck_tile::fp8_t& y,
|
||||
const float& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::fp8_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()<ck_tile::fp16_t, ck_tile::fp8_t>(ck_tile::fp16_t& y, const ck_tile::fp8_t& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::fp16_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()<ck_tile::fp8_t, ck_tile::fp16_t>(ck_tile::fp8_t& y, const ck_tile::fp16_t& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::fp8_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()<ck_tile::bf8_t, ck_tile::bf8_t>(ck_tile::bf8_t& y, const ck_tile::bf8_t& x) const
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::bf8_t>(float& y,
|
||||
const ck_tile::bf8_t& x) const
|
||||
{
|
||||
y = type_convert<float>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<ck_tile::bf8_t, float>(ck_tile::bf8_t& y,
|
||||
const float& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::bf8_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()<ck_tile::fp16_t, ck_tile::bf8_t>(ck_tile::fp16_t& y, const ck_tile::bf8_t& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::fp16_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()<ck_tile::bf8_t, ck_tile::fp16_t>(ck_tile::bf8_t& y, const ck_tile::fp16_t& x) const
|
||||
{
|
||||
y = ck_tile::type_convert<ck_tile::bf8_t>(x);
|
||||
e = ck_tile::type_convert<E>(result);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user