November 24, 2025

Fast Reed-Solomon erasure code in Go

Here's the Go (golang) version. It is about one third slower than the C version. It uses unsafe pointers to do 64-bit loads and stores.

Go version, M2 MacBook Air:

K=16, N=24, L=64KB (1MB) encoding time 0.0023s. Input rate: 430 MB/s.
K=16, N=24, L=1MB (16MB) encoding time 0.037s. Input rate: 430 MB/s.
K=17, N=20, L=64MB (1088MB) encoding time 1s. Input rate: 1070 MB/s.
K=17, N=20, L=1MB (17MB) encoding time 0.015s. Input rate: 1130 MB/s.
K=17, N=20, L=200KB (3.3MB) encoding time 0.0029s. Input rate: 1140 MB/s.

test_all(17, 20, 1024) # 1140 combinations. 17KB input. 0.017s.
test_all(10, 16, 1024) # 8008 combinations. 10KB input. 0.1s.
test_all(16, 24, 16) # 735471 combinations. 256B input. 3s.
test_all(16, 24, 1024) # 735471 combinations. 16KB input. 21.4s.


package main

import "bytes"
import "fmt"
import "math/rand"
import "time"
import "unsafe"

// bytearray is an alias for a byte slice; every data or parity chunk is one bytearray.
type bytearray = []uint8
// b8 is the 64-bit word type used to load and store eight chunk bytes at once.
type b8 = uint64

// gf256_log maps a nonzero field element to its exponent with respect to the
// base 3. It is filled by gf256_init_tables; entry 0 is never written.
var gf256_log [256]byte
// gf256_antilog maps an exponent in 0..254 back to the field element 3^exponent.
var gf256_antilog [255]byte
// gf256_mul_table holds every product: entry (r<<8)|c is r*c in GF(2^8), so
// the 256 entries starting at r<<8 form the "multiply by r" row.
var gf256_mul_table [256 * 256]byte

// gf256_slow_mul multiplies x and y in GF(2^8) by shift-and-add, reducing
// modulo the polynomial 0x11b. It processes y one bit at a time and is meant
// for building lookup tables rather than for bulk data.
func gf256_slow_mul(x byte, y byte) byte {
	var product byte
	for ; y != 0; y >>= 1 {
		if y&1 == 1 {
			product ^= x // addition in GF(2^8) is XOR
		}
		// Multiply x by 2: shift, and if the x^8 term overflowed, fold it back
		// in with the low byte of the reduction polynomial.
		overflow := x&0x80 != 0
		x <<= 1
		if overflow {
			x ^= 0x1b
		}
	}
	return product
}

func gf256_init_tables() {
	var z byte = 1
	for x := range 255 {
		if z == 0 {
			panic("z is zero")
		}
		gf256_antilog[x] = z
		gf256_log[z] = byte(x)
		z = gf256_slow_mul(z, 3)
	}
	for r := range 256 {
		for c := range 256 {
			gf256_mul_table[(r<<8)|c] = gf256_slow_mul(byte(r), byte(c))
		}
	}
}

func gf256_div(x byte, y byte) byte {
	if y == 0 {
		panic("divide by zero")
	}
	if x == 0 {
		return 0
	}
	return gf256_antilog[(int(gf256_log[x])-int(gf256_log[y])+255)%255]
}

// gf256_mul returns x * y in GF(2^8) with a single lookup in gf256_mul_table.
// The table must have been filled by gf256_init_tables.
func gf256_mul(x byte, y byte) byte {
	idx := int(x)<<8 | int(y)
	return gf256_mul_table[idx]
}

func lag_interp_weights(K int, xs []byte, x byte) [][]byte {
	// assume len(xs) == K
	// assume len(w) == K
	w := make([][]byte, K)
	for i := range K {
		t := byte(1)
		for j := range K {
			if i != j {
				t = gf256_mul(t, gf256_div((x^xs[j]), xs[i]^xs[j]))
			}
		}
		z := int(t) << 8
		w[i] = gf256_mul_table[z : z+256]
	}
	return w
}

// compute_chunks derives T chunks from K known ones by Lagrange interpolation
// over GF(2^8).
//
// Each byte position is treated independently: the bytes at that position in
// the chunks numbered xs[0..K-1] are the values of one polynomial at the
// points xs, and the function evaluates that polynomial at the points
// ts[0..T-1]. The T resulting chunks are freshly allocated, L bytes each, and
// returned in the order of ts.
//
// Parameters:
//   - K, xs: count and chunk numbers of the known chunks (len(xs) == K).
//   - L: chunk length in bytes; must be a multiple of 8.
//   - chunks: indexed by chunk number; chunks[xs[i]] must hold at least L
//     bytes and start on an 8-byte boundary. Only the first L bytes are read.
//   - T, ts: count and chunk numbers of the chunks to compute (len(ts) == T).
//     Callers are expected to pass ts disjoint from xs.
//
// It panics when L is not a multiple of 8, when a known chunk is shorter than
// L, or when a buffer is not 8-byte aligned. With L == 0 it returns T empty
// chunks.
func compute_chunks(K int, xs []byte, L int, chunks []bytearray, T int, ts []byte) []bytearray {
	if L%8 != 0 {
		panic(fmt.Sprintf("L must be a multiple of 8: %d", L))
	}
	inputs := make([]bytearray, K)
	for i := range K {
		src := chunks[xs[i]]
		// The inner loop reads whole 64-bit words through unsafe pointers, so a
		// short chunk would be read past its end; reject it up front.
		if len(src) < L {
			panic(fmt.Sprintf("chunk %d holds %d bytes, need %d", xs[i], len(src), L))
		}
		if L > 0 && uintptr(unsafe.Pointer(&src[0]))%8 != 0 {
			panic("input chunks must be aligned to 8 byte boundaries")
		}
		inputs[i] = src
	}
	weights := make([][][]byte, T)
	for j := range T {
		weights[j] = lag_interp_weights(K, xs, ts[j])
	}
	results := make([]bytearray, T)
	for j := range T {
		results[j] = make(bytearray, L)
		if L > 0 && uintptr(unsafe.Pointer(&results[j][0]))%8 != 0 {
			panic("allocated result chunk is not aligned to 8 byte boundaries")
		}
	}
	for j := range T {
		r := results[j]
		ww := weights[j]
		for i := range K {
			inp := inputs[i]
			// Every weight row is a 256-entry window of gf256_mul_table. Viewing
			// it as an array pointer lets byte-valued indexes skip bounds checks.
			w := (*[256]byte)(ww[i])
			for z := 0; z < L; z += 8 {
				// Load eight input bytes at once, multiply each by the weight
				// via table lookup, and XOR (GF addition) the eight products
				// into the result with a single 64-bit store.
				dd := *(*b8)(unsafe.Pointer(&inp[z]))
				prod := b8(w[byte(dd)]) |
					b8(w[byte(dd>>8)])<<8 |
					b8(w[byte(dd>>16)])<<16 |
					b8(w[byte(dd>>24)])<<24 |
					b8(w[byte(dd>>32)])<<32 |
					b8(w[byte(dd>>40)])<<40 |
					b8(w[byte(dd>>48)])<<48 |
					b8(w[byte(dd>>56)])<<56
				*(*b8)(unsafe.Pointer(&r[z])) ^= prod
			}
		}
	}
	return results
}

// now returns the current time; the benchmark code uses it to take timestamps.
func now() time.Time {
	t := time.Now()
	return t
}

// encode_chunks computes the N-K parity chunks for K data chunks.
//
// chunks is expected to have N entries: entries 0..K-1 hold the data, L bytes
// each, and entries K..N-1 are unused on input. The input slice is not
// modified; the function returns a new slice of N entries whose first K
// entries are the caller's data chunks (shared, not copied) and whose entries
// K..N-1 are the newly computed parity chunks.
//
// It panics with "Invalid K or N" unless 0 < K < N <= 256, and propagates any
// panic raised by compute_chunks.
func encode_chunks(K int, N int, L int, chunks []bytearray) []bytearray {
	if K <= 0 || K >= N || N > 256 {
		panic("Invalid K or N")
	}
	// Chunk number p is the evaluation point p: the first K points carry the
	// data, the remaining ones are where the parity is evaluated.
	points := make([]byte, N)
	for p := range points {
		points[p] = byte(p)
	}
	dataPoints, parityPoints := points[:K], points[K:]
	parity := compute_chunks(K, dataPoints, L, chunks, N-K, parityPoints)

	encoded := make([]bytearray, N)
	copy(encoded, chunks)
	copy(encoded[K:], parity)
	return encoded
}

// decode_chunks recovers the K data chunks from any K surviving chunks.
//
// chunks is indexed by chunk number (0..N-1). A nil or empty entry means the
// chunk is missing; every present chunk must be L bytes long. The first K
// present chunks, in index order, are used as interpolation inputs. The input
// slice is not modified: the function returns a new slice of K entries in
// which surviving data chunks are shared with the input and missing ones are
// freshly computed.
//
// It panics with "Invalid K or N" unless 0 < K < N <= 256, panics when fewer
// than K chunks are present, and propagates any panic from compute_chunks.
func decode_chunks(K int, N int, L int, chunks []bytearray) []bytearray {
	if !(K > 0 && K < N && N <= 256) {
		panic("Invalid K or N")
	}
	// Use one definition of "present" (non-empty) for both scans below, so a
	// non-nil but empty chunk can never be counted as both available and missing.
	xs := make([]byte, 0, K)
	for x := 0; x < N && len(xs) < K; x++ {
		if len(chunks[x]) != 0 {
			xs = append(xs, byte(x))
		}
	}
	if len(xs) != K {
		panic(fmt.Sprintf("Only found %d valid chunks. Need %d chunks", len(xs), K))
	}
	// Only missing data chunks need to be rebuilt; missing parity is ignored.
	ts := make([]byte, 0, N-K)
	for x := range K {
		if len(chunks[x]) == 0 {
			ts = append(ts, byte(x))
		}
	}
	results := compute_chunks(K, xs, L, chunks, len(ts), ts)
	chunks2 := make([]bytearray, K)
	copy(chunks2, chunks)
	for i, x := range ts {
		chunks2[x] = results[i]
	}
	return chunks2
}

// temp is a large scratch buffer, allocated by setup_temp and overwritten by
// clear_cache.
var temp bytearray

// TEMP_SIZE is the size of temp in bytes (256 MiB).
const TEMP_SIZE = 256 * 1024 * 1024

// setup_temp allocates the TEMP_SIZE-byte scratch buffer temp. It must be
// called before clear_cache.
func setup_temp() {
	scratch := make(bytearray, TEMP_SIZE)
	temp = scratch
}

// clear_cache overwrites all TEMP_SIZE bytes of temp and prints
// "Cache cleared". It is called before each timed section; writing through a
// buffer this large is intended to push previously touched chunk data out of
// the CPU caches. It panics if setup_temp has not allocated temp.
func clear_cache() {
	for i := 0; i < TEMP_SIZE; i++ {
		temp[i] = byte(i)
	}
	fmt.Println("Cache cleared")
}

// gen_data builds a deterministic test input: a slice of N chunk slots whose
// first K entries are L-byte chunks of pseudo-random data and whose remaining
// N-K entries are nil.
//
// The bytes come from a private generator seeded with 1, so every call
// returns the same data. A private generator is used instead of rand.Seed on
// the global one because rand.Seed is deprecated and is a no-op by default
// from Go 1.24 on, which would make the data differ from run to run.
func gen_data(K int, N int, L int) []bytearray {
	rng := rand.New(rand.NewSource(1))
	r := make([]bytearray, N)
	for i := range K {
		chunk := make(bytearray, L)
		for j := range chunk {
			chunk[j] = byte(rng.Intn(256))
		}
		r[i] = chunk
	}
	return r
}

// chunks_equal reports whether the first K chunks of chunks1 and chunks2 have
// identical contents. On the first mismatch it prints the chunk index and
// returns false.
func chunks_equal(K int, chunks1 []bytearray, chunks2 []bytearray) bool {
	for idx := 0; idx < K; idx++ {
		if bytes.Equal(chunks1[idx], chunks2[idx]) {
			continue
		}
		fmt.Printf("Chunks %d are not equal\n", idx)
		return false
	}
	return true
}

// test_random benchmarks encoding and decoding and checks round-trip
// correctness on randomly chosen erasure patterns.
//
// It generates encode_rounds data sets of K chunks of L bytes, times encoding
// each into N chunks, then runs decode_rounds decodes. Each decode keeps K
// chunk numbers picked at random out of N, reconstructs the data from them and
// compares it with the original; a mismatch panics with "chunks not equal".
// Timings and input rates are printed to standard output.
//
// The erasure patterns come from a private generator seeded with 2, so they
// are the same on every run. (rand.Seed on the global generator is deprecated
// and is a no-op by default from Go 1.24 on.)
func test_random(K int, N int, L int, encode_rounds int, decode_rounds int) {
	test_chunks := make([][]bytearray, encode_rounds)
	for r := range encode_rounds {
		test_chunks[r] = gen_data(K, N, L)
	}
	clear_cache()
	estart := now()
	for r := range encode_rounds {
		test_chunks[r] = encode_chunks(K, N, L, test_chunks[r])
	}
	encode_time := now().Sub(estart).Seconds()
	fmt.Printf("Encode time: %fs %d rounds %fs/round  input rate: %.3fMB/s\n", encode_time, encode_rounds, encode_time/float64(encode_rounds), float64(K)*float64(L)*float64(encode_rounds)/1024/1024/encode_time)
	clear_cache()
	dstart := now()
	rng := rand.New(rand.NewSource(2))
	for round := range decode_rounds {
		chunks := test_chunks[round%encode_rounds]
		// survivors holds the K chunks that "survived"; all other slots stay nil.
		survivors := make([]bytearray, N)
		for range K {
			// Draw chunk numbers until one lands on a slot not yet filled.
			x := rng.Intn(N)
			for len(survivors[x]) != 0 {
				x = rng.Intn(N)
			}
			survivors[x] = chunks[x]
		}
		results := decode_chunks(K, N, L, survivors)
		if !chunks_equal(K, chunks, results) {
			panic("chunks not equal")
		}
	}
	decode_time := now().Sub(dstart).Seconds()
	fmt.Printf("Decode time: %fs %d rounds %fs/round  input rate: %.3fMB/s\n", decode_time, decode_rounds, decode_time/float64(decode_rounds), float64(K)*float64(L)*float64(decode_rounds)/1024/1024/decode_time)
	fmt.Printf("random_%d_%d_%d succeeded\n", K, N, L)
}

// start_comb returns the first K-element combination in the order used by
// next_comb: the indices 0, 1, ..., K-1.
func start_comb(K int) []int {
	comb := make([]int, K)
	for pos := range comb {
		comb[pos] = pos
	}
	return comb
}

// next_comb advances v, a strictly increasing K-element combination of the
// numbers 0..N-1, to the next combination in place. It returns false when v
// now holds a valid next combination and true when the enumeration is
// finished, in which case the contents of v are no longer a valid combination.
func next_comb(K int, N int, v []int) bool {
	// Try positions from the right: bump the position, reset everything after
	// it to consecutive values, and accept if the last element still fits.
	for pos := K - 1; pos >= 0; pos-- {
		if v[pos] >= N-1 {
			continue
		}
		v[pos]++
		for k := pos + 1; k < K; k++ {
			v[k] = v[k-1] + 1
		}
		if v[K-1] < N {
			return false
		}
	}
	return true
}

// test_all encodes one data set of K chunks of L bytes into N chunks and then
// decodes it from every possible K-element subset of the N chunks, checking
// each result against the original data. A mismatch panics with
// "chunks not equal". Timings, input rates and the number of combinations
// are printed to standard output.
func test_all(K int, N int, L int) {
	chunks := gen_data(K, N, L)
	clear_cache()
	encodeStart := now()
	chunks = encode_chunks(K, N, L, chunks)
	encode_time := now().Sub(encodeStart).Seconds()
	fmt.Printf("Encode time: %fs  input rate: %.3fMB/s\n", encode_time, float64(K)*float64(L)/1024/1024/encode_time)
	clear_cache()
	decodeStart := now()
	count := 0
	survivors := start_comb(K)
	for {
		// Present only the chunks named by the current combination.
		subset := make([]bytearray, N)
		for _, idx := range survivors[:K] {
			subset[idx] = chunks[idx]
		}
		count++
		decoded := decode_chunks(K, N, L, subset)
		if !chunks_equal(K, chunks, decoded) {
			panic("chunks not equal")
		}
		if next_comb(K, N, survivors) {
			break
		}
	}
	decode_time := now().Sub(decodeStart).Seconds()
	fmt.Printf("Decode time: %fs %d combinations %fs/combination  input rate: %.3fMB/s\n", decode_time, count, decode_time/float64(count), float64(K)*float64(L)*float64(count)/1024/1024/decode_time)
	fmt.Printf("Encode time plus decode time: %f\n", encode_time+decode_time)
	fmt.Printf("all_%d_%d_%d %d combinations succeeded\n", K, N, L, count)
}

// main initializes the GF(2^8) tables and the cache-clearing buffer, then
// runs the random-erasure benchmarks followed by the exhaustive
// all-combinations tests.
func main() {
	gf256_init_tables()
	setup_temp()

	randomRuns := []struct {
		K, N, L      int
		encode, decode int
	}{
		{16, 24, 65536, 100, 100},
		{16, 24, 1024 * 1024, 10, 10},
		{17, 20, 64 * 1024 * 1024, 4, 1},
		{17, 20, 1024 * 1024, 10, 10},
		{17, 20, 200 * 1024, 10, 10},
	}
	for _, run := range randomRuns {
		test_random(run.K, run.N, run.L, run.encode, run.decode)
	}

	exhaustiveRuns := [][3]int{
		{17, 20, 1024},
		{10, 16, 1024},
		{16, 24, 16},
		{16, 24, 1024},
	}
	for _, run := range exhaustiveRuns {
		test_all(run[0], run[1], run[2])
	}
}





No comments:

Post a Comment