/*
 * Copyright (C) 2013, 2014 Giorgio Vazzana
 *
 * This file is part of Seren.
 *
 * Seren is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Seren is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

#include <string.h>
#include "sha256.h"
#include "rw.h"

/*
 * SHA-256 round constants K[0..63] (FIPS 180-4 §4.2.2). Entry k[i] is
 * added in round i of the compression loop in sha256_blocks().
 */
static const uint32_t k[64] = {
0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5,
0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,
0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3,
0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174,
0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc,
0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,
0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7,
0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967,
0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13,
0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85,
0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3,
0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,
0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5,
0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208,
0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2
};

/*
 * Rotate the 32-bit word x right by c bits. x is expected to be a uint32_t;
 * c must satisfy 0 < c < 32, because a shift by 32 would be undefined.
 */
#define RR(x, c) ((x) >> (c) | (x) << (32 - (c)))

/*
 * SHA-256 round functions (FIPS 180-4 §4.1.2):
 *   F_S0 = Sigma0, F_S1 = Sigma1, F_CH = Ch (choose), F_MA = Maj (majority).
 * Every argument is parenthesized so the macros stay correct when given
 * compound expressions (e.g. F_CH(x + y, f, g)). Arguments are evaluated
 * more than once, so do not pass expressions with side effects.
 */
#define F_S0(a) (RR((a), 2) ^ RR((a), 13) ^ RR((a), 22))
#define F_S1(e) (RR((e), 6) ^ RR((e), 11) ^ RR((e), 25))
#define F_CH(e, f, g) (((e) & (f)) ^ (~(e) & (g)))
#define F_MA(a, b, c) (((a) & (b)) ^ ((a) & (c)) ^ ((b) & (c)))

/*
 * SHA-256 compression: process nblocks consecutive 64-byte blocks from buf
 * and fold each one into the running hash state H, which is updated in place.
 *
 * H:       eight 32-bit state words; the caller initializes them and reads the
 *          updated values back after the call
 * buf:     input bytes; nblocks * 64 of them are read
 * nblocks: number of whole 64-byte blocks to process (0 leaves H unchanged)
 *
 * No padding is done here. sha256_buffer() pads the final partial block
 * itself before calling this function.
 *
 * Building with -DSMALL selects a compact rolled main loop. Otherwise the 64
 * rounds are fully unrolled via the OP/OP8 macros below.
 */
static void sha256_blocks(uint32_t H[8], const uint8_t *buf, size_t nblocks)
{
	size_t n;
	/* working variables a..h of the compression function */
	uint32_t a, b, c, d, e, f, g, h;
	/* i: word/round index; w: message schedule for the current block */
	uint32_t i, w[64];

	/* Process the message in successive 512-bit (64 bytes) chunks */
	for (n = 0; n < nblocks; n++) {

		/* Create a 64-entry message schedule array w[0..63] of 32-bit words: */
		/* 1 - Copy chunk into first 16 words w[0..15] of the message schedule array */
		/* NOTE(review): read_be32 comes from rw.h; presumably a big-endian 32-bit load, as SHA-256 requires -- confirm there */
		for (i = 0; i < 16; i++)
			w[i] = read_be32(buf + n * 64 + i * 4);
		/* 2 - Extend the first 16 words into the remaining 48 words w[16..63] of the message schedule array */
		/* s0/s1 below are the "small sigma" functions sigma0/sigma1 of FIPS 180-4 */
		for (i = 16; i < 64; i++) {
			uint32_t s0, s1;

			s0   = RR(w[i-15],  7) ^ RR(w[i-15], 18) ^ (w[i-15] >>  3);
			s1   = RR(w[i- 2], 17) ^ RR(w[i- 2], 19) ^ (w[i- 2] >> 10);
			w[i] = w[i-16] + s0 + w[i-7] + s1;
		}

		/* Initialize working variables to current hash value */
		a = H[0];
		b = H[1];
		c = H[2];
		d = H[3];
		e = H[4];
		f = H[5];
		g = H[6];
		h = H[7];

#ifdef SMALL
		/* Main loop: one round per iteration, then shift a..h down by one slot */
		for (i = 0; i < 64; i++) {
			uint32_t s0, s1, ch, ma, tmp1, tmp2;

			/* s0 = Sigma0(a), s1 = Sigma1(e), ch = Ch(e,f,g), ma = Maj(a,b,c) */
			s0 = F_S0(a);
			s1 = F_S1(e);
			ch = F_CH(e, f, g);
			ma = F_MA(a, b, c);
			tmp1 = h + s1 + ch + k[i] + w[i];
			tmp2 = s0 + ma;

			h = g;
			g = f;
			f = e;
			e = d + tmp1;
			d = c;
			c = b;
			b = a;
			a = tmp1 + tmp2;
		}
#else

/*
 * One round, arranged so that no variable shuffling is needed. The slot
 * passed as `h` first accumulates T1 = h + Sigma1(e) + Ch(e,f,g) + k[i] + w[i].
 * Then `d` becomes d + T1, which is the next round's e. Finally `h` becomes
 * T1 + Sigma0(a) + Maj(a,b,c), which is the next round's a.
 */
#define OP(i, a, b, c, d, e, f, g, h)                  \
	do {                                               \
		h = h + F_S1(e) + F_CH(e, f, g) + k[i] + w[i]; \
		d = d + h;                                     \
		h = h + F_S0(a) + F_MA(a, b, c);               \
	} while (0)

/*
 * Eight consecutive rounds. Rather than moving values between a..h after each
 * round, each successive OP call passes the argument list rotated right by
 * one position. After eight rounds every variable is back in its original role.
 * NOTE(review): OP and OP8 are never #undef'd, so they stay defined for the
 * rest of the file.
 */
#define OP8(i) \
	OP(i,   a, b, c, d, e, f, g, h); OP(i+1, h, a, b, c, d, e, f, g); \
	OP(i+2, g, h, a, b, c, d, e, f); OP(i+3, f, g, h, a, b, c, d, e); \
	OP(i+4, e, f, g, h, a, b, c, d); OP(i+5, d, e, f, g, h, a, b, c); \
	OP(i+6, c, d, e, f, g, h, a, b); OP(i+7, b, c, d, e, f, g, h, a)

		/* Main loop: 64 rounds, fully unrolled */
		OP8( 0);
		OP8( 8);
		OP8(16);
		OP8(24);
		OP8(32);
		OP8(40);
		OP8(48);
		OP8(56);
#endif

		/* Add the compressed chunk to the current hash value */
		H[0] = H[0] + a;
		H[1] = H[1] + b;
		H[2] = H[2] + c;
		H[3] = H[3] + d;
		H[4] = H[4] + e;
		H[5] = H[5] + f;
		H[6] = H[6] + g;
		H[7] = H[7] + h;
	}
}

/*
 * Compute the SHA-256 digest of len bytes at buf in a single call.
 *
 * buf:    input bytes; may be NULL when len is 0
 * len:    number of input bytes
 * digest: receives the 32-byte digest (the eight state words, each written
 *         with write_be32)
 */
void sha256_buffer(const uint8_t *buf, size_t len, uint8_t digest[32])
{
	/*
	 * Holds the trailing partial block plus padding: at most 63 data bytes,
	 * then 0x80, zeros and the 8-byte length. That is 64 or 128 bytes, so one
	 * or two blocks.
	 */
	uint8_t  fillbuf[128];
	size_t   n, nblocks;
	uint64_t bitlen;
	/* Initial hash value H(0) (FIPS 180-4 §5.3.3) */
	uint32_t h[8] = { 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19 };

	/*
	 * Message length in bits. Widen before multiplying: with a 32-bit size_t,
	 * len * 8 would wrap for inputs of 512 MiB or more and corrupt the
	 * length field.
	 */
	bitlen  = (uint64_t)len * 8;
	nblocks = len / 64;
	n       = len % 64;

	/* process the bulk of the data (whole blocks) */
	sha256_blocks(h, buf, nblocks);

	/*
	 * Copy the rest. Skipping an empty tail means a NULL buf with len == 0 is
	 * never offset or handed to memcpy; both are undefined even for a zero
	 * length.
	 */
	if (n > 0)
		memcpy(fillbuf, buf + nblocks * 64, n);

	/*
	 * Padding: one 1 bit (0x80), zeros up to 56 mod 64, then the 64-bit bit
	 * length via write_be64 (presumably big-endian, as SHA-256 requires).
	 */
	fillbuf[n++] = 0x80;
	while (n % 64 != 56)
		fillbuf[n++] = 0;
	write_be64(fillbuf + n, bitlen);
	n += 8;

	/* process last blocks (n is now 64 or 128) */
	sha256_blocks(h, fillbuf, n / 64);

	/* fill in sha256 hash */
	for (n = 0; n < 8; n++)
		write_be32(digest + 4 * n, h[n]);
}

#ifdef SELFTEST
#include <stdio.h>
#include <stdlib.h>
#include <sys/time.h>

/* Number of bytes hashed in each timed benchmark run */
#define BUFDIM (10*1024*1024)
/* Number of timed benchmark runs */
#define RUNS   3

/*
 * Known-answer test vector: a NUL-terminated input string and its expected
 * digest as 64 lowercase hex characters. Both point at string literals, so
 * they are const-qualified (writing through them would be undefined).
 */
struct tv {
	const char *input;
	const char *hexdigest;
};

/*
 * Self-test driver. It checks sha256_buffer() against known-answer vectors.
 * When any command-line argument is given, it also times RUNS hashes of a
 * zero-filled BUFDIM-byte buffer and reports throughput in MiB/s.
 *
 * Returns 0 on success, or 1 if a vector mismatches or the benchmark buffer
 * cannot be allocated.
 */
int main(int argc, char *argv[])
{
	static const struct tv tv[] = {
		{"", "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"},
		{"a", "ca978112ca1bbdcafac231b39a23dc4da786eff8147c4e72b9807785afee48bb"},
		{"abc", "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"},
		{"message digest", "f7846f55cf23e14eebeab5b4e1550cad5b509e3348fbc4efa3a1413d393cb650"},
		{"abcdefghijklmnopqrstuvwxyz", "71c480df93d6ae2f1efad1447c66c9525e316218cf51fc8d9ed832f2daf18b73"},
		{"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789", "db4bfcbd4da0cd85a60c3c37d3fbd8805c77f15fc6b1fdfe614ee0a7c8fdb4c0"},
		{"12345678901234567890123456789012345678901234567890123456789012345678901234567890", "f371bc4a311f2b009eef952dd83ca80e2b60026c8e935592d0f9c308453c813e"}
	};
	/* Derived from the table so adding a vector needs no other change */
	const size_t ntv = sizeof(tv) / sizeof(tv[0]);

	size_t  i, j;
	uint8_t digest[32];
	char    hexdigest[64+1];

	for (i = 0; i < ntv; i++) {
		sha256_buffer((const uint8_t *)tv[i].input, strlen(tv[i].input), digest);
		for (j = 0; j < 32; j++)
			snprintf(hexdigest + 2 * j, sizeof(hexdigest) - 2 * j, "%02x", (unsigned)digest[j]);
		fprintf(stderr, "sha256('%s') = %s\n", tv[i].input, hexdigest);

		/* strcmp also catches a vector whose expected string is not 64 chars */
		if (strcmp(tv[i].hexdigest, hexdigest) != 0) {
			fprintf(stderr, "ERROR\n");
			return 1;
		}
	}
	fprintf(stderr, "OK\n");

	/* Any argument requests the benchmark; its value is not inspected */
	if (argc > 1 && argv[1]) {
		uint8_t *buf = calloc(1, BUFDIM);
		struct timeval t0, t1;
		double interval;

		if (!buf) {
			/* Report instead of silently skipping the requested benchmark */
			fprintf(stderr, "speed: cannot allocate %d bytes\n", BUFDIM);
			return 1;
		}

		fprintf(stderr, "speed:\n");
		for (i = 0; i < RUNS; i++) {
			gettimeofday(&t0,  NULL);
			sha256_buffer(buf, BUFDIM, digest);
			gettimeofday(&t1,  NULL);

			interval = (double)(t1.tv_sec - t0.tv_sec) + (double)(t1.tv_usec - t0.tv_usec) / 1000000.0;
			/* BUFDIM / 2^20 bytes: the figure is mebibytes, not megabits, per second */
			fprintf(stderr, "  run = %zu, time = %.3fs, speed = %.2f MiB/s\n",
			        i, interval, (double)BUFDIM / (interval * 1024.0 * 1024.0));
		}
		free(buf);
	}

	return 0;
}
#endif