#include <stdio.h>

/*
 * Selects generate_parity()'s implementation: 1 computes the check bits two
 * at a time with find_parity_32x2() (one 64-bit multiply per pair), 0 calls
 * find_parity_32() once per check bit.
 * this is a ~30% win for CPUs with fast 64x64->64 multiplication, but a huge loss otherwise
 */
#define PARALLEL_PARITY 1

/* Codeword layout: DATA_BITS data bits stored above PARITY_BITS check bits. */
#define DATA_BITS    26
#define PARITY_BITS  6
/*
 * NOTE(review): not referenced in the visible code. generate_parity() and
 * correct_single_bit_error() store the overall-parity ("extra") bit at bit 0,
 * not at this position (bit 5 holds the Hamming check for position 1) --
 * confirm intent before relying on this macro.
 */
#define EXTRA_BIT_POSITION (PARITY_BITS - 1)
/* Total codeword width; with the values above this fills a 32-bit word. */
#define CODE_BITS    (DATA_BITS + PARITY_BITS)
/* Number of distinct data words; the check_* functions iterate over all of them. */
#define NUM_DATA_WORDS (1 << DATA_BITS)

/*
 * Maps a Hamming position to the codeword bit that physically stores it.
 *
 * A textbook Hamming code places check bits at the power-of-two positions.
 * This codeword instead keeps every check bit in the low bits (0..5) and the
 * data in bits 6..31. So positions 1, 2, 4, 8 and 16 map to bits 5, 4, 3, 2
 * and 1, and the remaining positions map to the data bits in descending order.
 * Entry 0 maps to bit 0, which holds the overall parity bit.
 * The index is the syndrome computed in correct_single_bit_error().
 */
unsigned char permutation_table[CODE_BITS] = {
	/* positions  0..7  */  0,  5,  4, 31,  3, 30, 29, 28,
	/* positions  8..15 */  2, 27, 26, 25, 24, 23, 22, 21,
	/* positions 16..23 */  1, 20, 19, 18, 17, 16, 15, 14,
	/* positions 24..31 */ 13, 12, 11, 10,  9,  8,  7,  6
};

/*
 * Parity of a 32-bit word: returns 1 if an odd number of bits are set,
 * 0 otherwise. Assumes unsigned is 32 bits wide (folding starts at >> 16).
 *
 * A multiply-based variant (the same trick as find_parity_32x2) can be
 * slightly faster, but only on CPUs with fast hardware multiplication.
 */
unsigned find_parity_32(unsigned x)
{
	unsigned folded = x;

	/* Fold the word onto its low nibble; XOR-folding preserves parity. */
	folded ^= folded >> 16;
	folded ^= folded >> 8;
	folded ^= folded >> 4;

	/* 0x6996 is a 16-entry bit table: bit n is the parity of n (0..15). */
	return (0x6996u >> (folded & 0xfu)) & 1u;
}

/*
 * Computes two 32-bit parities with one 64-bit multiply.
 * Returns parity(a) in bit 0 and parity(b) in bit 1.
 * courtesy of neon/nocturnal :-)
 */
unsigned find_parity_32x2(unsigned a, unsigned b)
{
	/* a occupies nibbles 0..7 and b nibbles 8..15 of one 64-bit word. */
	unsigned long long word = ((unsigned long long)b << 32) | a;
	unsigned long long nibs;

	/* Afterwards bit 4k holds the parity of nibble k; all other bits are cleared. */
	nibs = word ^ (word >> 1);
	nibs ^= nibs >> 2;
	nibs &= 0x1111111111111111ULL;

	/*
	 * Multiplying by 0x11111111 sums eight consecutive nibble parities into
	 * each nibble. Nibble 7 receives a's eight parities and nibble 15 receives
	 * b's. No nibble's sum exceeds 8, so nothing carries between nibbles, and
	 * the low bit of each sum is the parity.
	 */
	nibs *= 0x11111111ULL;

	return (unsigned)(((nibs >> 28) & 1u) | ((nibs >> 59) & 2u));
}

unsigned generate_parity(unsigned data)
{
#if PARALLEL_PARITY
	return find_parity_32x2(data & 0x03b4e996, data & 0x00007fff) |
		(find_parity_32x2(data & 0x003f80ff, data & 0x01c78f0f) << 2) |
		(find_parity_32x2(data & 0x02d9b333, data & 0x036ad555) << 4);
#else
	unsigned parity1 = find_parity_32(data & 0x036ad555);
	unsigned parity2 = find_parity_32(data & 0x02d9b333);
	unsigned parity3 = find_parity_32(data & 0x01c78f0f);
	unsigned parity4 = find_parity_32(data & 0x003f80ff);
	unsigned parity5 = find_parity_32(data & 0x00007fff);
	unsigned parity6 = find_parity_32(data & 0x03b4e996);

	return parity6 | (parity5 << 1) | (parity4 << 2) | (parity3 << 3) | (parity2 << 4) | (parity1 << 5);
#endif
}

/*
 * Encodes a data word: the data sits above the low PARITY_BITS bits, which
 * hold the check bits from generate_parity(). Data bits that do not fit
 * above the check bits are shifted out. Callers are presumably expected to
 * pass values below NUM_DATA_WORDS.
 */
unsigned make_codeword(unsigned data)
{
	unsigned check_bits = generate_parity(data);
	unsigned payload = data << PARITY_BITS;

	return payload | check_bits;
}

/*
 * Returns nonzero if the stored check bits disagree with those recomputed
 * from the data part. Detects all single- and double-bit errors.
 */
int has_error(unsigned code)
{
	unsigned stored = code & ((1u << PARITY_BITS) - 1u);
	unsigned expected = generate_parity(code >> PARITY_BITS);

	return stored != expected;
}

/*
 * Returns 1 when the codeword shows a double-bit error, otherwise 0.
 * A double-bit error means some check bit mismatches while the whole
 * 32-bit word still has even parity.
 * A single-bit error instead leaves the overall parity odd.
 */
int has_double_error(unsigned code)
{
	unsigned syndrome = (generate_parity(code >> PARITY_BITS) ^ code)
	                    & ((1u << PARITY_BITS) - 1u);

	if (syndrome == 0) {
		return 0;
	}
	return find_parity_32(code) == 0;
}

/*
 * Corrects any single-bit error. Assumes there is no double-bit error;
 * check with has_double_error() first.
 *
 * 1. The five Hamming check bits (codeword bits 5..1) are compared with
 *    freshly computed ones.
 * 2. The mismatches form the Hamming position of the flipped bit, with bit 5
 *    as the least significant. permutation_table maps that position to a
 *    codeword bit, which is flipped.
 * 3. Bit 0 (the overall parity bit) is recomputed over the other 31 bits.
 *    This also repairs an error in bit 0 itself.
 *
 * Returns the corrected codeword. An error-free codeword comes back unchanged.
 */
unsigned correct_single_bit_error(unsigned code)
{
	unsigned parity_diff = generate_parity(code >> PARITY_BITS) ^ code;
	unsigned bp = 0, i;

	/* Syndrome: check bit (PARITY_BITS - 1 - i) contributes bit i of the position. */
	for (i = 0; i < PARITY_BITS - 1; ++i) {
		if (parity_diff & (1u << (PARITY_BITS - 1 - i))) {
			bp |= 1u << i;
		}
	}

	if (bp != 0) {
		/*
		 * Flip the wrong bit. The shift must be unsigned: permutation_table
		 * yields bit 31 for position 3, and 1 << 31 overflows a 32-bit int,
		 * which is undefined behavior.
		 */
		code ^= 1u << permutation_table[bp];
	}

	/* Recompute the overall parity bit (bit 0) so the whole word is even. */
	code &= ~1u;
	return code | find_parity_32(code);
}

/*
 * Exhaustive check over all NUM_DATA_WORDS data words: every freshly encoded
 * codeword must pass has_error() and has_double_error() as error-free.
 * Prints a progress dot every 2^20 data words and an "ERROR:" line per
 * failure.
 */
void check_zero_bit_detection(void)
{
	unsigned i;
	printf("Checking zero bit detection.");
	fflush(stdout);

	for (i = 0; i < NUM_DATA_WORDS; ++i) {
		unsigned code = make_codeword(i);

		/* progress indicator: i is a multiple of 2^20 */
		if ((i & 0xfffff) == 0) {
			printf(".");
			fflush(stdout);
		}

		if (has_error(code)) {
			printf("ERROR: Failed zero-bit test 1 for %x\n", i);
		}
		if (has_double_error(code)) {
			printf("ERROR: Failed zero-bit test 2 for %x\n", i);
		}
	}

	printf("\n");
}

/*
 * Exhaustive check over every data word and every single-bit corruption of
 * its codeword:
 *  - has_error() must report an error;
 *  - has_double_error() must not;
 *  - correct_single_bit_error() must restore the original codeword.
 * Prints a progress dot every 2^20 data words and an "ERROR:" line per
 * failure.
 */
void check_single_bit_detection(void)
{
	unsigned i, j;
	printf("Checking single bit detection and correction.");
	fflush(stdout);

	for (i = 0; i < NUM_DATA_WORDS; ++i) {
		unsigned code = make_codeword(i);

		/* progress indicator: i is a multiple of 2^20 */
		if ((i & 0xfffff) == 0) {
			printf(".");
			fflush(stdout);
		}

		for (j = 0; j < CODE_BITS; ++j) {
			/* 1u: j reaches 31, and 1 << 31 overflows a 32-bit int (UB). */
			unsigned corrupted_code = code ^ (1u << j);

			if (!has_error(corrupted_code)) {
				printf("ERROR: Failed single-bit test 1 for %x with bit %u flipped\n", i, j);
			}
			if (has_double_error(corrupted_code)) {
				printf("ERROR: Failed single-bit test 2 for %x with bit %u flipped\n", i, j);
			}
			if (correct_single_bit_error(corrupted_code) != code) {
				printf("ERROR: Failed single-bit correction test for %x with bit %u flipped\n", i, j);
			}
		}
	}

	printf("\n");
}

/*
 * Exhaustive check over every data word and every pair of distinct flipped
 * bits in its codeword: has_error() and has_double_error() must both report
 * the corruption. Prints a progress dot every 2^20 data words and an
 * "ERROR:" line per failing pair.
 */
void check_double_bit_detection(void)
{
	unsigned i, j, k;
	printf("Checking double bit detection.");
	fflush(stdout);

	for (i = 0; i < NUM_DATA_WORDS; ++i) {
		unsigned code = make_codeword(i);

		/* progress indicator: i is a multiple of 2^20 */
		if ((i & 0xfffff) == 0) {
			printf(".");
			fflush(stdout);
		}

		for (j = 0; j < CODE_BITS; ++j) {
			/*
			 * Flipping bits j and k is symmetric (XOR), so each unordered
			 * pair is tested once. Starting k at j + 1 also excludes j == k.
			 */
			for (k = j + 1; k < CODE_BITS; ++k) {
				/* 1u: shifts reach 31, and 1 << 31 overflows a 32-bit int (UB). */
				unsigned corrupted_code = code ^ (1u << j) ^ (1u << k);

				if (!has_error(corrupted_code)) {
					printf("ERROR: Failed double-bit test 1 for %x with bit %u and %u flipped\n", i, j, k);
				}
				if (!has_double_error(corrupted_code)) {
					printf("ERROR: Failed double-bit test 2 for %x with bit %u and %u flipped\n", i, j, k);
				}
			}
		}
	}

	printf("\n");
}

/* Runs the exhaustive self-checks: 0, then 1, then 2 flipped bits per codeword. */
int main(void)
{
	check_zero_bit_detection();
	check_single_bit_detection();
	check_double_bit_detection();

	return 0;
}
