mirror of
https://github.com/Ed94/Odin.git
synced 2026-06-21 05:05:00 -07:00
Add builtin outer_product
This commit is contained in:
@@ -2017,6 +2017,66 @@ bool check_builtin_procedure(CheckerContext *c, Operand *operand, Ast *call, i32
|
||||
operand->type = check_matrix_type_hint(operand->type, type_hint);
|
||||
break;
|
||||
}
|
||||
|
||||
case BuiltinProc_outer_product: {
|
||||
Operand x = {};
|
||||
Operand y = {};
|
||||
check_expr(c, &x, ce->args[0]);
|
||||
if (x.mode == Addressing_Invalid) {
|
||||
return false;
|
||||
}
|
||||
check_expr(c, &y, ce->args[1]);
|
||||
if (y.mode == Addressing_Invalid) {
|
||||
return false;
|
||||
}
|
||||
if (!is_operand_value(x) || !is_operand_value(y)) {
|
||||
error(call, "'%.*s' expects only arrays", LIT(builtin_name));
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!is_type_array(x.type) && !is_type_array(y.type)) {
|
||||
gbString s1 = type_to_string(x.type);
|
||||
gbString s2 = type_to_string(y.type);
|
||||
error(call, "'%.*s' expects only arrays, got %s and %s", LIT(builtin_name), s1, s2);
|
||||
gb_string_free(s2);
|
||||
gb_string_free(s1);
|
||||
return false;
|
||||
}
|
||||
|
||||
Type *xt = base_type(x.type);
|
||||
Type *yt = base_type(y.type);
|
||||
GB_ASSERT(xt->kind == Type_Array);
|
||||
GB_ASSERT(yt->kind == Type_Array);
|
||||
if (!are_types_identical(xt->Array.elem, yt->Array.elem)) {
|
||||
gbString s1 = type_to_string(xt->Array.elem);
|
||||
gbString s2 = type_to_string(yt->Array.elem);
|
||||
error(call, "'%.*s' mismatched element types, got %s vs %s", LIT(builtin_name), s1, s2);
|
||||
gb_string_free(s2);
|
||||
gb_string_free(s1);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (xt->Array.count == 0 || yt->Array.count == 0) {
|
||||
gbString s1 = type_to_string(x.type);
|
||||
gbString s2 = type_to_string(y.type);
|
||||
error(call, "'%.*s' expects only arrays of non-zero length, got %s and %s", LIT(builtin_name), s1, s2);
|
||||
gb_string_free(s2);
|
||||
gb_string_free(s1);
|
||||
return false;
|
||||
}
|
||||
|
||||
i64 max_count = xt->Array.count*yt->Array.count;
|
||||
if (max_count > MAX_MATRIX_ELEMENT_COUNT) {
|
||||
error(call, "Product of the array lengths exceed the maximum matrix element count, got %d, expected a maximum of %d", cast(int)max_count, MAX_MATRIX_ELEMENT_COUNT);
|
||||
return false;
|
||||
}
|
||||
|
||||
operand->mode = Addressing_Value;
|
||||
operand->type = alloc_type_matrix(xt->Array.elem, xt->Array.count, yt->Array.count);
|
||||
operand->type = check_matrix_type_hint(operand->type, type_hint);
|
||||
break;
|
||||
}
|
||||
|
||||
|
||||
case BuiltinProc_simd_vector: {
|
||||
Operand x = {};
|
||||
|
||||
Reference in New Issue
Block a user