diff --git a/core/crypto/sm3/sm3.odin b/core/crypto/sm3/sm3.odin index 017cf2246..d12b254f7 100644 --- a/core/crypto/sm3/sm3.odin +++ b/core/crypto/sm3/sm3.odin @@ -10,11 +10,11 @@ package sm3 Implementation of the SM3 hashing algorithm, as defined in */ +import "core:encoding/endian" import "core:io" +import "core:math/bits" import "core:os" -import "../util" - /* High level API */ @@ -110,6 +110,9 @@ init :: proc(ctx: ^Sm3_Context) { ctx.state[5] = IV[5] ctx.state[6] = IV[6] ctx.state[7] = IV[7] + + ctx.length = 0 + ctx.bitlength = 0 } update :: proc(ctx: ^Sm3_Context, data: []byte) { @@ -119,14 +122,14 @@ update :: proc(ctx: ^Sm3_Context, data: []byte) { if ctx.bitlength > 0 { n := copy(ctx.x[ctx.bitlength:], data[:]) ctx.bitlength += u64(n) - if ctx.bitlength == 64 { + if ctx.bitlength == BLOCK_SIZE { block(ctx, ctx.x[:]) ctx.bitlength = 0 } data = data[n:] } - if len(data) >= 64 { - n := len(data) &~ (64 - 1) + if len(data) >= BLOCK_SIZE { + n := len(data) &~ (BLOCK_SIZE - 1) block(ctx, data[:n]) data = data[n:] } @@ -138,45 +141,44 @@ update :: proc(ctx: ^Sm3_Context, data: []byte) { final :: proc(ctx: ^Sm3_Context, hash: []byte) { length := ctx.length - pad: [64]byte + pad: [BLOCK_SIZE]byte pad[0] = 0x80 - if length % 64 < 56 { - update(ctx, pad[0:56 - length % 64]) + if length % BLOCK_SIZE < 56 { + update(ctx, pad[0:56 - length % BLOCK_SIZE]) } else { - update(ctx, pad[0:64 + 56 - length % 64]) + update(ctx, pad[0:BLOCK_SIZE + 56 - length % BLOCK_SIZE]) } length <<= 3 - util.PUT_U64_BE(pad[:], length) + endian.unchecked_put_u64be(pad[:], length) update(ctx, pad[0:8]) assert(ctx.bitlength == 0) - util.PUT_U32_BE(hash[0:], ctx.state[0]) - util.PUT_U32_BE(hash[4:], ctx.state[1]) - util.PUT_U32_BE(hash[8:], ctx.state[2]) - util.PUT_U32_BE(hash[12:], ctx.state[3]) - util.PUT_U32_BE(hash[16:], ctx.state[4]) - util.PUT_U32_BE(hash[20:], ctx.state[5]) - util.PUT_U32_BE(hash[24:], ctx.state[6]) - util.PUT_U32_BE(hash[28:], ctx.state[7]) + for i := 0; i < DIGEST_SIZE / 4; i += 1 { + endian.unchecked_put_u32be(hash[i * 4:], ctx.state[i]) + } } /* SM3 implementation */ +BLOCK_SIZE :: 64 + Sm3_Context :: struct { state: [8]u32, - x: [64]byte, + x: [BLOCK_SIZE]byte, bitlength: u64, length: u64, } +@(private) IV := [8]u32 { 0x7380166f, 0x4914b2b9, 0x172442d7, 0xda8a0600, 0xa96f30bc, 0x163138aa, 0xe38dee4d, 0xb0fb0e4e, } +@(private) block :: proc "contextless" (ctx: ^Sm3_Context, buf: []byte) { buf := buf @@ -186,20 +188,18 @@ block :: proc "contextless" (ctx: ^Sm3_Context, buf: []byte) { state0, state1, state2, state3 := ctx.state[0], ctx.state[1], ctx.state[2], ctx.state[3] state4, state5, state6, state7 := ctx.state[4], ctx.state[5], ctx.state[6], ctx.state[7] - for len(buf) >= 64 { + for len(buf) >= BLOCK_SIZE { for i := 0; i < 16; i += 1 { - j := i * 4 - w[i] = - u32(buf[j]) << 24 | u32(buf[j + 1]) << 16 | u32(buf[j + 2]) << 8 | u32(buf[j + 3]) + w[i] = endian.unchecked_get_u32be(buf[i * 4:]) } for i := 16; i < 68; i += 1 { - p1v := w[i - 16] ~ w[i - 9] ~ util.ROTL32(w[i - 3], 15) + p1v := w[i - 16] ~ w[i - 9] ~ bits.rotate_left32(w[i - 3], 15) // @note(zh): inlined P1 w[i] = p1v ~ - util.ROTL32(p1v, 15) ~ - util.ROTL32(p1v, 23) ~ - util.ROTL32(w[i - 13], 7) ~ + bits.rotate_left32(p1v, 15) ~ + bits.rotate_left32(p1v, 23) ~ + bits.rotate_left32(w[i - 13], 7) ~ w[i - 6] } for i := 0; i < 64; i += 1 { @@ -210,8 +210,8 @@ block :: proc "contextless" (ctx: ^Sm3_Context, buf: []byte) { e, f, g, h := state4, state5, state6, state7 for i := 0; i < 16; i += 1 { - v1 := util.ROTL32(u32(a), 12) - ss1 := util.ROTL32(v1 + u32(e) + util.ROTL32(0x79cc4519, i), 7) + v1 := bits.rotate_left32(u32(a), 12) + ss1 := bits.rotate_left32(v1 + u32(e) + bits.rotate_left32(0x79cc4519, i), 7) ss2 := ss1 ~ v1 // @note(zh): inlined FF1 @@ -219,15 +219,18 @@ block :: proc "contextless" (ctx: ^Sm3_Context, buf: []byte) { // @note(zh): inlined GG1 tt2 := u32(e ~ f ~ g) + u32(h) + ss1 + w[i] - a, b, c, d = tt1, a, util.ROTL32(u32(b), 9), c + a, b, c, d = tt1, a, bits.rotate_left32(u32(b), 9), c // @note(zh): inlined P0 e, f, g, h = - (tt2 ~ util.ROTL32(tt2, 9) ~ util.ROTL32(tt2, 17)), e, util.ROTL32(u32(f), 19), g + (tt2 ~ bits.rotate_left32(tt2, 9) ~ bits.rotate_left32(tt2, 17)), + e, + bits.rotate_left32(u32(f), 19), + g } for i := 16; i < 64; i += 1 { - v := util.ROTL32(u32(a), 12) - ss1 := util.ROTL32(v + u32(e) + util.ROTL32(0x7a879d8a, i % 32), 7) + v := bits.rotate_left32(u32(a), 12) + ss1 := bits.rotate_left32(v + u32(e) + bits.rotate_left32(0x7a879d8a, i % 32), 7) ss2 := ss1 ~ v // @note(zh): inlined FF2 @@ -235,10 +238,13 @@ block :: proc "contextless" (ctx: ^Sm3_Context, buf: []byte) { // @note(zh): inlined GG2 tt2 := u32(((e & f) | ((~e) & g)) + h) + ss1 + w[i] - a, b, c, d = tt1, a, util.ROTL32(u32(b), 9), c + a, b, c, d = tt1, a, bits.rotate_left32(u32(b), 9), c // @note(zh): inlined P0 e, f, g, h = - (tt2 ~ util.ROTL32(tt2, 9) ~ util.ROTL32(tt2, 17)), e, util.ROTL32(u32(f), 19), g + (tt2 ~ bits.rotate_left32(tt2, 9) ~ bits.rotate_left32(tt2, 17)), + e, + bits.rotate_left32(u32(f), 19), + g } state0 ~= a @@ -250,7 +256,7 @@ block :: proc "contextless" (ctx: ^Sm3_Context, buf: []byte) { state6 ~= g state7 ~= h - buf = buf[64:] + buf = buf[BLOCK_SIZE:] } ctx.state[0], ctx.state[1], ctx.state[2], ctx.state[3] = state0, state1, state2, state3