#include <stdio.h>

#define DATA_BITS    11
#define PARITY_BITS  5
#define EXTRA_BIT_POSITION (PARITY_BITS - 1)
#define CODE_BITS    (DATA_BITS + PARITY_BITS)
#define NUM_DATA_WORDS (1 << DATA_BITS)

unsigned char hamming_lookup[NUM_DATA_WORDS];

/*
 * Translates a Hamming bit position (the syndrome computed in
 * correct_single_bit_error()) into the bit index used in the stored codeword.
 *
 * Needed since we store all the parity at the end of the word, not at the expected
 * power-of-two bit positions. This is the inverse of the mapping
 * (0..15) -> (0, 8, 4, 2, 1, the rest in ascending order)
 * so, for example, Hamming position 1 is stored at bit 4 and position 8 at bit 1.
 * correct_single_bit_error() only consults the table for a non-zero syndrome,
 * so entry 0 is not used there.
 */
unsigned char permutation_table[CODE_BITS] = {
	0, 4, 3, 5, 2, 6, 7, 8, 1, 9, 10, 11, 12, 13, 14, 15
};

/*
 * Compute the parity field for a data word.
 *
 * Only the low DATA_BITS bits of `data` are examined. The returned value is
 * laid out as follows:
 *   bit 4..1 : the four Hamming check bits (first check in bit 4, last in bit 1)
 *   bit 0    : overall parity of all data bits and the four check bits
 */
unsigned generate_parity(unsigned data)
{
	/* Data bits covered by each of the four Hamming checks. */
	static const unsigned coverage[PARITY_BITS - 1] = {
		0x55B,	/* data bits 0 1 3 4 6 8 10 */
		0x66D,	/* data bits 0 2 3 5 6 9 10 */
		0x78E,	/* data bits 1 2 3 7 8 9 10 */
		0x7F0	/* data bits 4 5 6 7 8 9 10 */
	};
	unsigned field = 0;
	unsigned overall = 0;
	unsigned remaining;
	unsigned k;

	/* Seed the overall parity with the parity of the data bits themselves. */
	for (remaining = data & ((1u << DATA_BITS) - 1); remaining != 0; remaining >>= 1) {
		overall ^= remaining & 1;
	}

	for (k = 0; k < PARITY_BITS - 1; ++k) {
		unsigned check = 0;

		/* Parity of the data bits this check covers. */
		for (remaining = data & coverage[k]; remaining != 0; remaining >>= 1) {
			check ^= remaining & 1;
		}

		/* The overall parity also covers the check bits. */
		overall ^= check;
		field |= check << (EXTRA_BIT_POSITION - k);
	}

	return field | overall;
}

/*
 * Build the codeword for a data word: the data goes in the bits above the
 * parity field, and the parity field is taken from hamming_lookup[].
 *
 * hamming_lookup[] must already have been filled by generate_lookup().
 * Only as many low bits of `data` as the table can index are used; anything
 * above that is discarded so the table is never read out of bounds.
 */
unsigned make_codeword(unsigned data)
{
	/* The table has one entry per data word, so this keeps the index in range. */
	const unsigned table_len = (unsigned)(sizeof hamming_lookup / sizeof hamming_lookup[0]);
	const unsigned payload = data % table_len;

	return (payload << PARITY_BITS) | hamming_lookup[payload];
}

/*
 * Fill hamming_lookup[] with the parity field of every possible data word.
 * Prints a progress line before starting.
 */
void generate_lookup()
{
	unsigned word = 0;

	printf("Generating lookup table.\n");

	while (word < NUM_DATA_WORDS) {
		hamming_lookup[word] = generate_parity(word);
		++word;
	}
}

/*
 * Report whether a codeword is inconsistent: returns non-zero when the parity
 * field stored in `code` differs from the one in hamming_lookup[] for its
 * data part, zero otherwise.
 *
 * Can detect all single or double bit errors.
 *
 * hamming_lookup[] must already have been filled by generate_lookup().
 * Bits of `code` above the codeword are ignored so that the table is never
 * indexed out of bounds.
 */
int has_error(unsigned code)
{
	/* The table has one entry per data word, so this keeps the index in range. */
	const unsigned table_len = (unsigned)(sizeof hamming_lookup / sizeof hamming_lookup[0]);
	const unsigned stored_parity = code & ((1u << PARITY_BITS) - 1);
	const unsigned data = (code >> PARITY_BITS) % table_len;

	return hamming_lookup[data] != stored_parity;
}

/*
 * Distinguish a double-bit error from the other cases.
 *
 * Two observations are combined:
 *   - whether the four Hamming check bits stored in `code` match the ones
 *     in hamming_lookup[] for its data part, and
 *   - the parity over all CODE_BITS bits of `code` (0 for a valid codeword).
 *
 *   Hamming checks   overall parity   meaning                         result
 *   match            0                no error                        0
 *   mismatch         1                single-bit error                0
 *   match            1                only the overall bit is wrong   0
 *   mismatch         0                double-bit error                1
 *
 * hamming_lookup[] must already have been filled by generate_lookup().
 * Bits of `code` above CODE_BITS are ignored so that the table is never
 * indexed out of bounds.
 */
int has_double_error(unsigned code)
{
	/* Restrict to the codeword so the data part is a valid table index. */
	const unsigned word = code & ((1u << CODE_BITS) - 1);
	const unsigned data = word >> PARITY_BITS;

	/* Drop bit 0 (the overall parity bit) to compare only the Hamming checks. */
	const unsigned stored_checks = (word & ((1u << PARITY_BITS) - 1)) >> 1;
	const unsigned expected_checks = (unsigned)hamming_lookup[data] >> 1;

	unsigned overall = 0;
	unsigned i;

	for (i = 0; i < CODE_BITS; ++i) {
		overall ^= (word >> i) & 1;
	}

	/* Hamming says error while the overall parity says OK => double error. */
	return stored_checks != expected_checks && overall == 0;
}

/*
 * Correct any single-bit error -- assumes there are no double-bit errors.
 *
 * The four Hamming checks are recomputed from the data bits and compared with
 * the stored check bits; the mismatches form a syndrome that names the Hamming
 * position of the faulty bit. permutation_table[] turns that position into the
 * bit index used in the stored codeword. Bit 0 (the overall parity) is always
 * rewritten from the parity of the remaining codeword bits.
 */
unsigned correct_single_bit_error(unsigned code)
{
	/* Data bits (relative to the data part) covered by each Hamming check. */
	static const unsigned coverage[PARITY_BITS - 1] = {
		0x55B,	/* data bits 0 1 3 4 6 8 10 */
		0x66D,	/* data bits 0 2 3 5 6 9 10 */
		0x78E,	/* data bits 1 2 3 7 8 9 10 */
		0x7F0	/* data bits 4 5 6 7 8 9 10 */
	};
	const unsigned payload = code >> PARITY_BITS;
	unsigned syndrome = 0;
	unsigned low_bit = 0;
	unsigned k;

	/* Parity of every codeword bit except bit 0 itself. */
	for (k = 1; k < CODE_BITS; ++k) {
		low_bit ^= (code >> k) & 1;
	}

	for (k = 0; k < PARITY_BITS - 1; ++k) {
		const unsigned stored = (code >> (PARITY_BITS - 1 - k)) & 1;
		unsigned expected = 0;
		unsigned covered;

		for (covered = payload & coverage[k]; covered != 0; covered >>= 1) {
			expected ^= covered & 1;
		}

		if (expected != stored) {
			syndrome |= 1u << k;
		}
	}

	if (syndrome != 0) {
		/* Flip the faulty bit; that also toggles the parity of bits 1 and up. */
		code ^= 1u << permutation_table[syndrome];
		low_bit ^= 1;
	}

	/* Rewrite the overall parity bit. */
	return (code & ~1u) | low_bit;
}

/*
 * Self-test: an uncorrupted codeword must be reported as error-free by both
 * has_error() and has_double_error(). Failures are printed; nothing is returned.
 */
void check_zero_bit_detection()
{
	unsigned data;

	printf("Checking zero bit detection.\n");

	for (data = 0; data < NUM_DATA_WORDS; ++data) {
		const unsigned clean = make_codeword(data);

		if (has_error(clean) != 0) {
			printf("ERROR: Failed zero-bit test 1 for %x\n", data);
		}
		if (has_double_error(clean) != 0) {
			printf("ERROR: Failed zero-bit test 2 for %x\n", data);
		}
	}
}

/*
 * Self-test: for every codeword and every single flipped bit, the error must
 * be detected, must not be classified as a double error, and must be repaired
 * back to the original codeword. Failures are printed; nothing is returned.
 */
void check_single_bit_detection()
{
	unsigned data;

	printf("Checking single bit detection and correction.\n");

	for (data = 0; data < NUM_DATA_WORDS; ++data) {
		const unsigned clean = make_codeword(data);
		unsigned bit;

		for (bit = 0; bit < CODE_BITS; ++bit) {
			const unsigned damaged = clean ^ (1u << bit);

			if (has_error(damaged) == 0) {
				printf("ERROR: Failed single-bit test 1 for %x with bit %u flipped\n", data, bit);
			}
			if (has_double_error(damaged) != 0) {
				printf("ERROR: Failed single-bit test 2 for %x with bit %u flipped\n", data, bit);
			}
			if (correct_single_bit_error(damaged) != clean) {
				printf("ERROR: Failed single-bit correction test for %x with bit %u flipped\n", data, bit);
			}
		}
	}
}

/*
 * Self-test: for every codeword and every ordered pair of distinct flipped
 * bits, the corruption must be detected by has_error() and classified as a
 * double error by has_double_error(). Failures are printed; nothing is returned.
 */
void check_double_bit_detection()
{
	unsigned data;

	printf("Checking double bit detection.\n");

	for (data = 0; data < NUM_DATA_WORDS; ++data) {
		const unsigned clean = make_codeword(data);
		unsigned first, second;

		for (first = 0; first < CODE_BITS; ++first) {
			for (second = 0; second < CODE_BITS; ++second) {
				unsigned damaged;

				/* Flipping the same bit twice is not an error at all. */
				if (first == second) {
					continue;
				}

				damaged = clean ^ (1u << first) ^ (1u << second);

				if (has_error(damaged) == 0) {
					printf("ERROR: Failed double-bit test 1 for %x with bit %u and %u flipped\n", data, first, second);
				}
				if (has_double_error(damaged) == 0) {
					printf("ERROR: Failed double-bit test 2 for %x with bit %u and %u flipped\n", data, first, second);
				}
			}
		}
	}
}

int main()
{
	generate_lookup();
	check_zero_bit_detection();
	check_single_bit_detection();
	check_double_bit_detection();

	return 0;
}
