From 10e4de3c015de149008d0dc573e391af99d82574 Mon Sep 17 00:00:00 2001 From: gingerBill Date: Wed, 25 May 2022 22:04:47 +0100 Subject: [PATCH] Add `intrinsics.simd_reduce_*` --- core/simd/simd.odin | 8 +++ src/check_builtin.cpp | 50 +++++++++++++++++ src/checker_builtin_procs.hpp | 16 ++++++ src/llvm_backend_proc.cpp | 102 +++++++++++++++++++++++++++++++--- 4 files changed, 167 insertions(+), 9 deletions(-) diff --git a/core/simd/simd.odin b/core/simd/simd.odin index 87386f91f..ad14855bd 100644 --- a/core/simd/simd.odin +++ b/core/simd/simd.odin @@ -34,6 +34,14 @@ ge :: intrinsics.simd_ge extract :: intrinsics.simd_extract replace :: intrinsics.simd_replace +reduce_add_ordered :: intrinsics.simd_reduce_add_ordered +reduce_mul_ordered :: intrinsics.simd_reduce_mul_ordered +reduce_min :: intrinsics.simd_reduce_min +reduce_max :: intrinsics.simd_reduce_max +reduce_and :: intrinsics.simd_reduce_and +reduce_or :: intrinsics.simd_reduce_or +reduce_xor :: intrinsics.simd_reduce_xor + splat :: #force_inline proc "contextless" ($T: typeid/#simd[$LANES]$E, value: E) -> T { return T{0..args[0]); if (x.mode == Addressing_Invalid) { return false; } + + if (!is_type_simd_vector(x.type)) { + error(x.expr, "'%.*s' expected a simd vector type", LIT(builtin_name)); + return false; + } + Type *elem = base_array_type(x.type); + if (!is_type_integer(elem) && !is_type_float(elem)) { + gbString xs = type_to_string(x.type); + error(x.expr, "'%.*s' expected a #simd type with an integer or floating-point element, got '%s'", LIT(builtin_name), xs); + gb_string_free(xs); + return false; + } + + operand->mode = Addressing_Value; + operand->type = base_array_type(x.type); + return true; + } + + case BuiltinProc_simd_reduce_and: + case BuiltinProc_simd_reduce_or: + case BuiltinProc_simd_reduce_xor: + { + Operand x = {}; + check_expr(c, &x, ce->args[0]); if (x.mode == Addressing_Invalid) { return false; } + + if (!is_type_simd_vector(x.type)) { + error(x.expr, "'%.*s' expected a simd vector type", LIT(builtin_name)); + return false; + } + Type *elem = base_array_type(x.type); + if (!is_type_integer(elem)) { + gbString xs = type_to_string(x.type); + error(x.expr, "'%.*s' expected a #simd type with an integer element, got '%s'", LIT(builtin_name), xs); + gb_string_free(xs); + return false; + } + + operand->mode = Addressing_Value; + operand->type = base_array_type(x.type); + return true; + } + + // case BuiltinProc_simd_rotate_left: // { // Operand x = {}; diff --git a/src/checker_builtin_procs.hpp b/src/checker_builtin_procs.hpp index f5d4111bc..98cc9f284 100644 --- a/src/checker_builtin_procs.hpp +++ b/src/checker_builtin_procs.hpp @@ -149,6 +149,14 @@ BuiltinProc__simd_begin, BuiltinProc_simd_extract, BuiltinProc_simd_replace, + + BuiltinProc_simd_reduce_add_ordered, + BuiltinProc_simd_reduce_mul_ordered, + BuiltinProc_simd_reduce_min, + BuiltinProc_simd_reduce_max, + BuiltinProc_simd_reduce_and, + BuiltinProc_simd_reduce_or, + BuiltinProc_simd_reduce_xor, BuiltinProc__simd_end, // Platform specific intrinsics @@ -401,6 +409,14 @@ gb_global BuiltinProc builtin_procs[BuiltinProc_COUNT] = { {STR_LIT("simd_extract"), 2, false, Expr_Expr, BuiltinProcPkg_intrinsics}, {STR_LIT("simd_replace"), 3, false, Expr_Expr, BuiltinProcPkg_intrinsics}, + + {STR_LIT("simd_reduce_add_ordered"), 1, false, Expr_Expr, BuiltinProcPkg_intrinsics}, + {STR_LIT("simd_reduce_mul_ordered"), 1, false, Expr_Expr, BuiltinProcPkg_intrinsics}, + {STR_LIT("simd_reduce_min"), 1, false, Expr_Expr, BuiltinProcPkg_intrinsics}, + {STR_LIT("simd_reduce_max"), 1, false, Expr_Expr, BuiltinProcPkg_intrinsics}, + {STR_LIT("simd_reduce_and"), 1, false, Expr_Expr, BuiltinProcPkg_intrinsics}, + {STR_LIT("simd_reduce_or"), 1, false, Expr_Expr, BuiltinProcPkg_intrinsics}, + {STR_LIT("simd_reduce_xor"), 1, false, Expr_Expr, BuiltinProcPkg_intrinsics}, {STR_LIT(""), 0, false, Expr_Stmt, BuiltinProcPkg_intrinsics}, diff --git a/src/llvm_backend_proc.cpp b/src/llvm_backend_proc.cpp index cfb69c654..c09265e7a 100644 --- a/src/llvm_backend_proc.cpp +++ b/src/llvm_backend_proc.cpp @@ -981,7 +981,7 @@ lbValue lb_emit_call(lbProcedure *p, lbValue value, Array const &args, return result; } -lbValue lb_build_builtin_simd_proc(lbProcedure *p, Ast *expr, TypeAndValue const &tv, BuiltinProcId id) { +lbValue lb_build_builtin_simd_proc(lbProcedure *p, Ast *expr, TypeAndValue const &tv, BuiltinProcId builtin_id) { ast_node(ce, CallExpr, expr); lbModule *m = p->module; @@ -1000,7 +1000,7 @@ lbValue lb_build_builtin_simd_proc(lbProcedure *p, Ast *expr, TypeAndValue const LLVMOpcode op_code = cast(LLVMOpcode)0; - switch (id) { + switch (builtin_id) { case BuiltinProc_simd_add: case BuiltinProc_simd_sub: case BuiltinProc_simd_mul: @@ -1008,14 +1008,14 @@ lbValue lb_build_builtin_simd_proc(lbProcedure *p, Ast *expr, TypeAndValue const case BuiltinProc_simd_rem: arg1 = lb_build_expr(p, ce->args[1]); if (is_float) { - switch (id) { + switch (builtin_id) { case BuiltinProc_simd_add: op_code = LLVMFAdd; break; case BuiltinProc_simd_sub: op_code = LLVMFSub; break; case BuiltinProc_simd_mul: op_code = LLVMFMul; break; case BuiltinProc_simd_div: op_code = LLVMFDiv; break; } } else { - switch (id) { + switch (builtin_id) { case BuiltinProc_simd_add: op_code = LLVMAdd; break; case BuiltinProc_simd_sub: op_code = LLVMSub; break; case BuiltinProc_simd_mul: op_code = LLVMMul; break; @@ -1053,7 +1053,7 @@ lbValue lb_build_builtin_simd_proc(lbProcedure *p, Ast *expr, TypeAndValue const Type *elem1 = base_array_type(arg1.type); bool is_masked = false; - switch (id) { + switch (builtin_id) { case BuiltinProc_simd_shl: op_code = LLVMShl; is_masked = false; break; case BuiltinProc_simd_shr: op_code = is_signed ? LLVMAShr : LLVMLShr; is_masked = false; break; case BuiltinProc_simd_shl_masked: op_code = LLVMShl; is_masked = true; break; @@ -1086,7 +1086,7 @@ lbValue lb_build_builtin_simd_proc(lbProcedure *p, Ast *expr, TypeAndValue const case BuiltinProc_simd_or: case BuiltinProc_simd_xor: arg1 = lb_build_expr(p, ce->args[1]); - switch (id) { + switch (builtin_id) { case BuiltinProc_simd_and: op_code = LLVMAnd; break; case BuiltinProc_simd_or: op_code = LLVMOr; break; case BuiltinProc_simd_xor: op_code = LLVMXor; break; @@ -1144,7 +1144,7 @@ lbValue lb_build_builtin_simd_proc(lbProcedure *p, Ast *expr, TypeAndValue const arg1 = lb_build_expr(p, ce->args[1]); if (is_float) { LLVMRealPredicate pred = cast(LLVMRealPredicate)0; - switch (id) { + switch (builtin_id) { case BuiltinProc_simd_eq: pred = LLVMRealOEQ; break; case BuiltinProc_simd_ne: pred = LLVMRealONE; break; case BuiltinProc_simd_lt: pred = LLVMRealOLT; break; @@ -1159,7 +1159,7 @@ lbValue lb_build_builtin_simd_proc(lbProcedure *p, Ast *expr, TypeAndValue const } } else { LLVMIntPredicate pred = cast(LLVMIntPredicate)0; - switch (id) { + switch (builtin_id) { case BuiltinProc_simd_eq: pred = LLVMIntEQ; break; case BuiltinProc_simd_ne: pred = LLVMIntNE; break; case BuiltinProc_simd_lt: pred = is_signed ? LLVMIntSLT :LLVMIntULT; break; @@ -1184,8 +1184,92 @@ lbValue lb_build_builtin_simd_proc(lbProcedure *p, Ast *expr, TypeAndValue const arg2 = lb_build_expr(p, ce->args[2]); res.value = LLVMBuildInsertElement(p->builder, arg0.value, arg2.value, arg1.value, ""); return res; + + case BuiltinProc_simd_reduce_add_ordered: + case BuiltinProc_simd_reduce_mul_ordered: + { + LLVMTypeRef llvm_elem = lb_type(m, elem); + LLVMValueRef args[2] = {}; + isize args_count = 0; + + char const *name = nullptr; + switch (builtin_id) { + case BuiltinProc_simd_reduce_add_ordered: + if (is_float) { + name = "llvm.vector.reduce.fadd"; + args[args_count++] = LLVMConstReal(llvm_elem, 0.0); + } else { + name = "llvm.vector.reduce.add"; + } + break; + case BuiltinProc_simd_reduce_mul_ordered: + if (is_float) { + name = "llvm.vector.reduce.fmul"; + args[args_count++] = LLVMConstReal(llvm_elem, 1.0); + } else { + name = "llvm.vector.reduce.mul"; + } + break; + } + args[args_count++] = arg0.value; + + + LLVMTypeRef types[1] = {lb_type(p->module, arg0.type)}; + unsigned id = LLVMLookupIntrinsicID(name, gb_strlen(name)); + GB_ASSERT_MSG(id != 0, "Unable to find %s.%s", name, LLVMPrintTypeToString(types[0])); + LLVMValueRef ip = LLVMGetIntrinsicDeclaration(p->module->mod, id, types, gb_count_of(types)); + + lbValue res = {}; + res.value = LLVMBuildCall(p->builder, ip, args, cast(unsigned)args_count, ""); + res.type = tv.type; + return res; + } + case BuiltinProc_simd_reduce_min: + case BuiltinProc_simd_reduce_max: + case BuiltinProc_simd_reduce_and: + case BuiltinProc_simd_reduce_or: + case BuiltinProc_simd_reduce_xor: + { + char const *name = nullptr; + switch (builtin_id) { + case BuiltinProc_simd_reduce_min: + if (is_float) { + name = "llvm.vector.reduce.fmin"; + } else if (is_signed) { + name = "llvm.vector.reduce.smin"; + } else { + name = "llvm.vector.reduce.umin"; + } + break; + case BuiltinProc_simd_reduce_max: + if (is_float) { + name = "llvm.vector.reduce.fmax"; + } else if (is_signed) { + name = "llvm.vector.reduce.smax"; + } else { + name = "llvm.vector.reduce.umax"; + } + break; + case BuiltinProc_simd_reduce_and: name = "llvm.vector.reduce.and"; break; + case BuiltinProc_simd_reduce_or: name = "llvm.vector.reduce.or"; break; + case BuiltinProc_simd_reduce_xor: name = "llvm.vector.reduce.xor"; break; + } + LLVMTypeRef types[1] = {lb_type(p->module, arg0.type)}; + unsigned id = LLVMLookupIntrinsicID(name, gb_strlen(name)); + GB_ASSERT_MSG(id != 0, "Unable to find %s.%s", name, LLVMPrintTypeToString(types[0])); + LLVMValueRef ip = LLVMGetIntrinsicDeclaration(p->module->mod, id, types, gb_count_of(types)); + + LLVMValueRef args[1] = {}; + args[0] = arg0.value; + + lbValue res = {}; + res.value = LLVMBuildCall(p->builder, ip, args, gb_count_of(args), ""); + res.type = tv.type; + return res; + } } - GB_PANIC("Unhandled simd intrinsic: '%.*s'", LIT(builtin_procs[id].name)); + GB_PANIC("Unhandled simd intrinsic: '%.*s'", LIT(builtin_procs[builtin_id].name)); + return {}; }