#include <stdio.h>
#include <time.h>
#include <stdlib.h>

/* Geometry of the single-error-correcting, double-error-detecting code:
 * DATA_BITS data bits protected by PARITY_BITS check bits. */
#define DATA_BITS    64
#define PARITY_BITS  8
/* Total number of bits in one codeword (data + check byte). */
#define CODE_BITS    (DATA_BITS + PARITY_BITS)
/* Number of random data words exercised by each check_* test. */
#define SAMPLE_SIZE  1000000
/* A progress dot is printed once every SAMPLE_PROGRESS samples. */
#define SAMPLE_PROGRESS  100000

/*
 * Maps a Hamming bit position (the error syndrome) to the index of the data
 * bit stored at that position, or -1 if no data bit lives there.
 *
 * Needed since we store all the parity at the end of the word, not at the
 * expected power-of-two bit positions. Positions 1, 2, 4, ..., 64 belong to
 * check bits; data bits 63..0 occupy the remaining positions 3..71 in
 * descending order. Syndrome 0 means "no data bit affected". Index 72 can
 * never result from a single-bit error; it is spelled out as -1 so that it is
 * not left implicitly zero-initialized (which would map it to data bit 0).
 */
int permutation_table[CODE_BITS + 1] = {
	/*  0- 7 */ -1, -1, -1, 63, -1, 62, 61, 60,
	/*  8-15 */ -1, 59, 58, 57, 56, 55, 54, 53,
	/* 16-23 */ -1, 52, 51, 50, 49, 48, 47, 46,
	/* 24-31 */ 45, 44, 43, 42, 41, 40, 39, 38,
	/* 32-39 */ -1, 37, 36, 35, 34, 33, 32, 31,
	/* 40-47 */ 30, 29, 28, 27, 26, 25, 24, 23,
	/* 48-55 */ 22, 21, 20, 19, 18, 17, 16, 15,
	/* 56-63 */ 14, 13, 12, 11, 10,  9,  8,  7,
	/* 64-71 */ -1,  6,  5,  4,  3,  2,  1,  0,
	/* 72    */ -1
};

/*
 * Parity (XOR of all bits) of the low 8 bits of x; higher bits are ignored.
 * Returns 1 if an odd number of those bits are set, 0 otherwise.
 */
unsigned find_parity_8(unsigned x)
{
	/* Bit n of this constant is the parity of the 4-bit value n. */
	const unsigned nibble_parity_table = 0x6996u;
	/* XOR-ing the two nibbles of the low byte preserves its parity. */
	unsigned folded_nibble = (x ^ (x >> 4)) & 0xFu;

	return (nibble_parity_table >> folded_nibble) & 1u;
}

/*
 * Parity (XOR of all 64 bits) of x: 1 if an odd number of bits are set,
 * 0 otherwise.
 */
unsigned find_parity_64(unsigned long long x)
{
	/* XOR-folding halves the width while preserving overall parity:
	 * 64 -> 32 here, then 32 -> 16 -> 8 in the loop. */
	unsigned folded = (unsigned)((x >> 32) ^ (x & 0xFFFFFFFFULL));
	unsigned shift;

	for (shift = 16; shift >= 8; shift /= 2) {
		folded ^= folded >> shift;
	}

	/* The low byte now carries the parity of the whole word. */
	return find_parity_8(folded);
}

/*
 * Compute the 8-bit check byte for a 64-bit data word.
 *
 * Check bit k (k = 1..7) is the parity of the data bits whose Hamming
 * position has bit (k - 1) set. Check bit 8 is chosen so that the whole
 * 72-bit codeword (data plus check byte) has even overall parity.
 * Layout of the result: check bit 1 in bit 7, ..., check bit 8 in bit 0.
 */
unsigned generate_parity(unsigned long long data)
{
	/* Data-bit coverage masks, in order check bit 1 .. check bit 8. */
	static const unsigned long long coverage[] = {
		0xDAB5556AAAAAAAD5ULL,
		0xB66CCCD9999999B3ULL,
		0x71E3C3C78787878FULL,
		0x0FE03FC07F807F80ULL,
		0x001FFFC0007FFF80ULL,
		0x0000003FFFFFFF80ULL,
		0x000000000000007FULL,
		0xED3A65B4CB4B34E9ULL
	};
	unsigned check_byte = 0;
	unsigned n;

	/* Shift in MSB-first so check bit 1 ends up in bit 7. */
	for (n = 0; n < sizeof coverage / sizeof coverage[0]; ++n) {
		check_byte = (check_byte << 1) | find_parity_64(data & coverage[n]);
	}

	return check_byte;
}

/*
 * Report whether the stored check byte disagrees with the one recomputed from
 * data, i.e. whether the syndrome is non-zero. This flags every single- or
 * double-bit error. Returns 1 on mismatch, 0 otherwise.
 */
int has_error(unsigned long long data, unsigned parity)
{
	const unsigned syndrome = generate_parity(data) ^ parity;

	return syndrome != 0;
}

/*
 * Classify a codeword as containing a double-bit error. Such an error is
 * detectable but cannot be corrected. Returns 1 for a double error,
 * 0 otherwise (no error, or a single correctable error).
 */
int has_double_error(unsigned long long data, unsigned parity)
{
	/* Zero syndrome: the codeword is consistent, so there is no error. */
	if (generate_parity(data) == parity) {
		return 0;
	}

	/*
	 * A valid codeword has even overall parity across data and check byte.
	 * One flipped bit makes it odd; two flipped bits leave it even while
	 * still producing a non-zero syndrome.
	 */
	return find_parity_64(data) == find_parity_8(parity);
}

/* Correct any single-bit error -- assumes there are no double-bit errors */
unsigned long long correct_single_bit_error(unsigned long long data, unsigned parity)
{
	unsigned parity_diff = generate_parity(data) ^ parity;
	unsigned bp = 0, i;

	for (i = 0; i < PARITY_BITS - 1; ++i) {
		if (parity_diff & (1 << (PARITY_BITS - 1 - i))) {
			bp |= (1 << i);
		}
	}
	
	/* flip the wrong bit, if it's in the data */
	if (permutation_table[bp] != -1) {
		data ^= (1ULL << permutation_table[bp]);
	}

	return data;
}

/*
 * Sanity-check uncorrupted codewords. For SAMPLE_SIZE random data words, the
 * freshly generated check byte must:
 *   - report no error;
 *   - not be classified as a double error;
 *   - give the 72-bit codeword even overall parity, which is the invariant
 *     has_double_error relies on.
 * Failures are printed; a progress dot is printed every SAMPLE_PROGRESS samples.
 */
void check_zero_bit_detection(void)
{
	unsigned i, k;
	printf("Checking zero bit detection.");
	fflush(stdout);

	for (i = 0; i < SAMPLE_SIZE; ++i) {
		unsigned long long data = 0;
		unsigned parity;

		/*
		 * Build 64 random bits from five 15-bit chunks. RAND_MAX is only
		 * guaranteed to be >= 32767, so fewer, wider chunks would leave some
		 * data bits permanently zero on such platforms.
		 */
		for (k = 0; k < 5; ++k) {
			data = (data << 15) ^ (unsigned long long)(rand() & 0x7FFF);
		}
		parity = generate_parity(data);

		if ((i % SAMPLE_PROGRESS) == 0) {
			printf(".");
			fflush(stdout);
		}

		if (has_error(data, parity)) {
			printf("ERROR: Failed zero-bit test 1 for %llx\n", data);
		}
		if (has_double_error(data, parity)) {
			printf("ERROR: Failed zero-bit test 2 for %llx\n", data);
		}
		/* A valid codeword must have even overall parity. */
		if (find_parity_64(data) != find_parity_8(parity)) {
			printf("ERROR: Failed zero-bit test 3 for %llx\n", data);
		}
	}

	printf("\n");
}

/*
 * Flip each of the CODE_BITS codeword bits in turn, for SAMPLE_SIZE random
 * data words. Bits 0..63 are data; the remaining bits index into the check
 * byte. Every single flip must:
 *   - be reported by has_error;
 *   - not be classified as a double error;
 *   - be undone by correct_single_bit_error.
 * Failures are printed.
 */
void check_single_bit_detection(void)
{
	unsigned i, j, k;
	printf("Checking single bit detection and correction.");
	fflush(stdout);

	for (i = 0; i < SAMPLE_SIZE; ++i) {
		unsigned long long data = 0;
		unsigned parity;

		/* Five 15-bit chunks: RAND_MAX is only guaranteed to be >= 32767. */
		for (k = 0; k < 5; ++k) {
			data = (data << 15) ^ (unsigned long long)(rand() & 0x7FFF);
		}
		parity = generate_parity(data);

		if ((i % SAMPLE_PROGRESS) == 0) {
			printf(".");
			fflush(stdout);
		}

		for (j = 0; j < CODE_BITS; ++j) {
			unsigned long long corrupted_data = data;
			unsigned corrupted_parity = parity;

			if (j < DATA_BITS) {
				corrupted_data ^= 1ULL << j;
			} else {
				corrupted_parity ^= 1u << (j - DATA_BITS);
			}

			if (!has_error(corrupted_data, corrupted_parity)) {
				printf("ERROR: Failed single-bit test 1 for %llx with bit %u flipped\n", data, j);
			}
			if (has_double_error(corrupted_data, corrupted_parity)) {
				printf("ERROR: Failed single-bit test 2 for %llx with bit %u flipped\n", data, j);
			}
			if (correct_single_bit_error(corrupted_data, corrupted_parity) != data) {
				printf("ERROR: Failed single-bit correction test for %llx with bit %u flipped\n", data, j);
			}
		}
	}

	printf("\n");
}

/*
 * Flip every distinct pair (j, k) of codeword bits, for SAMPLE_SIZE random
 * data words. Bits 0..63 are data; the remaining bits index into the check
 * byte. Each pair must be reported both by has_error and by has_double_error.
 * Failures are printed.
 */
void check_double_bit_detection(void)
{
	unsigned i, j, k, chunk;
	printf("Checking double bit detection.");
	fflush(stdout);

	for (i = 0; i < SAMPLE_SIZE; ++i) {
		unsigned long long data = 0;
		unsigned parity;

		/* Five 15-bit chunks: RAND_MAX is only guaranteed to be >= 32767. */
		for (chunk = 0; chunk < 5; ++chunk) {
			data = (data << 15) ^ (unsigned long long)(rand() & 0x7FFF);
		}
		parity = generate_parity(data);

		if ((i % SAMPLE_PROGRESS) == 0) {
			printf(".");
			fflush(stdout);
		}

		for (j = 0; j < CODE_BITS; ++j) {
			for (k = j + 1; k < CODE_BITS; ++k) {
				unsigned long long corrupted_data = data;
				unsigned corrupted_parity = parity;

				if (j < DATA_BITS) {
					corrupted_data ^= 1ULL << j;
				} else {
					corrupted_parity ^= 1u << (j - DATA_BITS);
				}
				if (k < DATA_BITS) {
					corrupted_data ^= 1ULL << k;
				} else {
					corrupted_parity ^= 1u << (k - DATA_BITS);
				}

				if (!has_error(corrupted_data, corrupted_parity)) {
					printf("ERROR: Failed double-bit test 1 for %llx with bit %u and %u flipped\n", data, j, k);
				}
				if (!has_double_error(corrupted_data, corrupted_parity)) {
					printf("ERROR: Failed double-bit test 2 for %llx with bit %u and %u flipped\n", data, j, k);
				}
			}
		}
	}

	printf("\n");
}

/*
 * Seed the PRNG from the wall clock and run the three self-checks, in order
 * of increasing number of flipped bits. Failures are reported on stdout only;
 * the exit status is always 0.
 */
int main(void)
{
	const unsigned seed = (unsigned)time(NULL);

	srand(seed);

	check_zero_bit_detection();
	check_single_bit_detection();
	check_double_bit_detection();
	return 0;
}
