From 421d45a7a76e53946e7441212af9d08a0d93ff68 Mon Sep 17 00:00:00 2001 From: gingerBill Date: Thu, 26 May 2022 18:06:26 +0100 Subject: [PATCH] Add `intrinsics.fused_mul_add` --- core/intrinsics/intrinsics.odin | 2 ++ core/simd/simd.odin | 3 ++ src/check_builtin.cpp | 53 +++++++++++++++++++++++++++++++++ src/checker_builtin_procs.hpp | 2 ++ src/llvm_backend_proc.cpp | 25 ++++++++++++++++ 5 files changed, 85 insertions(+) diff --git a/core/intrinsics/intrinsics.odin b/core/intrinsics/intrinsics.odin index bf8f56e63..c13e099c5 100644 --- a/core/intrinsics/intrinsics.odin +++ b/core/intrinsics/intrinsics.odin @@ -35,6 +35,8 @@ overflow_mul :: proc(lhs, rhs: $T) -> (T, bool) #optional_ok --- sqrt :: proc(x: $T) -> T where type_is_float(T) --- +fused_mul_add :: proc(a, b, c: $T) -> T where type_is_float(T) || (type_is_simd_vector(T) && type_is_float(type_elem_type(T))) --- + mem_copy :: proc(dst, src: rawptr, len: int) --- mem_copy_non_overlapping :: proc(dst, src: rawptr, len: int) --- mem_zero :: proc(ptr: rawptr, len: int) --- diff --git a/core/simd/simd.odin b/core/simd/simd.odin index 17d97f918..ce278bce7 100644 --- a/core/simd/simd.odin +++ b/core/simd/simd.odin @@ -109,6 +109,9 @@ count_zeros :: intrinsics.count_zeros count_trailing_zeros :: intrinsics.count_trailing_zeros count_leading_zeros :: intrinsics.count_leading_zeros +fused_mul_add :: intrinsics.fused_mul_add +fma :: intrinsics.fused_mul_add + to_array_ptr :: #force_inline proc "contextless" (v: ^#simd[$LANES]$E) -> ^[LANES]E { return (^[LANES]E)(v) } diff --git a/src/check_builtin.cpp b/src/check_builtin.cpp index ee805702d..19b78b46e 100644 --- a/src/check_builtin.cpp +++ b/src/check_builtin.cpp @@ -3681,6 +3681,59 @@ bool check_builtin_procedure(CheckerContext *c, Operand *operand, Ast *call, i32 } break; + case BuiltinProc_fused_mul_add: + { + Operand x = {}; + Operand y = {}; + Operand z = {}; + check_expr(c, &x, ce->args[0]); if (x.mode == Addressing_Invalid) return false; + check_expr(c, &y, ce->args[1]); if (y.mode == Addressing_Invalid) return false; + check_expr(c, &z, ce->args[2]); if (z.mode == Addressing_Invalid) return false; + + convert_to_typed(c, &y, x.type); if (y.mode == Addressing_Invalid) return false; + convert_to_typed(c, &x, y.type); if (x.mode == Addressing_Invalid) return false; + convert_to_typed(c, &z, x.type); if (z.mode == Addressing_Invalid) return false; + convert_to_typed(c, &x, z.type); if (x.mode == Addressing_Invalid) return false; + if (is_type_untyped(x.type)) { + gbString xts = type_to_string(x.type); + error(x.expr, "Expected a typed floating point value or #simd vector for '%.*s', got %s", LIT(builtin_name), xts); + gb_string_free(xts); + return false; + } + + Type *elem = core_array_type(x.type); + if (!is_type_float(x.type) && !(is_type_simd_vector(x.type) && is_type_float(elem))) { + gbString xts = type_to_string(x.type); + error(x.expr, "Expected a floating point or #simd vector value for '%.*s', got %s", LIT(builtin_name), xts); + gb_string_free(xts); + return false; + } + if (is_type_different_to_arch_endianness(elem)) { + GB_ASSERT(elem->kind == Type_Basic); + if (elem->Basic.flags & (BasicFlag_EndianLittle|BasicFlag_EndianBig)) { + gbString xts = type_to_string(x.type); + error(x.expr, "Expected a float which does not specify the explicit endianness for '%.*s', got %s", LIT(builtin_name), xts); + gb_string_free(xts); + return false; + } + } + + if (!are_types_identical(x.type, y.type) || !are_types_identical(y.type, z.type)) { + gbString xts = type_to_string(x.type); + gbString yts = type_to_string(y.type); + gbString zts = type_to_string(z.type); + error(x.expr, "Mismatched types for '%.*s', got %s vs %s vs %s", LIT(builtin_name), xts, yts, zts); + gb_string_free(zts); + gb_string_free(yts); + gb_string_free(xts); + return false; + } + + operand->mode = Addressing_Value; + operand->type = default_type(x.type); + } + break; + case BuiltinProc_mem_copy: case BuiltinProc_mem_copy_non_overlapping: { diff --git a/src/checker_builtin_procs.hpp b/src/checker_builtin_procs.hpp index 1b2c105f1..5859ce3ab 100644 --- a/src/checker_builtin_procs.hpp +++ b/src/checker_builtin_procs.hpp @@ -65,6 +65,7 @@ enum BuiltinProcId { BuiltinProc_overflow_mul, BuiltinProc_sqrt, + BuiltinProc_fused_mul_add, BuiltinProc_mem_copy, BuiltinProc_mem_copy_non_overlapping, @@ -348,6 +349,7 @@ gb_global BuiltinProc builtin_procs[BuiltinProc_COUNT] = { {STR_LIT("overflow_mul"), 2, false, Expr_Expr, BuiltinProcPkg_intrinsics}, {STR_LIT("sqrt"), 1, false, Expr_Expr, BuiltinProcPkg_intrinsics}, + {STR_LIT("fused_mul_add"), 3, false, Expr_Expr, BuiltinProcPkg_intrinsics}, {STR_LIT("mem_copy"), 3, false, Expr_Stmt, BuiltinProcPkg_intrinsics}, {STR_LIT("mem_copy_non_overlapping"), 3, false, Expr_Stmt, BuiltinProcPkg_intrinsics}, diff --git a/src/llvm_backend_proc.cpp b/src/llvm_backend_proc.cpp index 4c4000fec..a5dda7815 100644 --- a/src/llvm_backend_proc.cpp +++ b/src/llvm_backend_proc.cpp @@ -2005,6 +2005,31 @@ lbValue lb_build_builtin_proc(lbProcedure *p, Ast *expr, TypeAndValue const &tv, return res; } + case BuiltinProc_fused_mul_add: + { + Type *type = tv.type; + lbValue x = lb_emit_conv(p, lb_build_expr(p, ce->args[0]), type); + lbValue y = lb_emit_conv(p, lb_build_expr(p, ce->args[1]), type); + lbValue z = lb_emit_conv(p, lb_build_expr(p, ce->args[2]), type); + + + char const *name = "llvm.fma"; + LLVMTypeRef types[1] = {lb_type(p->module, type)}; + unsigned id = LLVMLookupIntrinsicID(name, gb_strlen(name)); + GB_ASSERT_MSG(id != 0, "Unable to find %s.%s", name, LLVMPrintTypeToString(types[0])); + LLVMValueRef ip = LLVMGetIntrinsicDeclaration(p->module->mod, id, types, gb_count_of(types)); + + LLVMValueRef args[3] = {}; + args[0] = x.value; + args[1] = y.value; + args[2] = z.value; + + lbValue res = {}; + res.value = LLVMBuildCall(p->builder, ip, args, gb_count_of(args), ""); + res.type = type; + return res; + } + case BuiltinProc_mem_copy: { lbValue dst = lb_build_expr(p, ce->args[0]);