November 09, 2025

Reed Solomon erasure code library

In part 1, I visually demonstrated a version of the Reed Solomon erasure code.

Let's turn that into a little library to encode and decode chunks of data.

API

encode(K, N, chunks): K and N are as explained in the previous post. chunks is a dictionary containing the data to encode. The keys are exactly 1 to K. The values are the data for each chunk (a list of numbers, each between 0 and 255). All the chunks must be of the same length (L). The function returns a dictionary of N-K redundant chunks in the same format as chunks, with keys from K+1 to N.

decode(K, N, chunks): K and N are the same as above. chunks is a dictionary of data in the same format as above. It should contain data for any K of the N chunks (original plus redundant), so the keys range from 1 to N. It returns a dictionary containing all the missing chunks from 1 to K.
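
Here is a minimal usage sketch. It uses the function names from the listing below (encode_chunks and decode_chunks, rather than the shorthand encode/decode above); the values of K, N and the chunk contents are made up purely for illustration.

K, N = 2, 4
chunks = {
    1: [10, 20, 30],   # original chunk 1 (three bytes)
    2: [40, 50, 60],   # original chunk 2, same length
}
redundant = encode_chunks(K, N, chunks)      # keys 3 and 4

# Pretend chunk 1 was lost; keep any K = 2 of the N = 4 chunks.
survivors = {2: chunks[2], 3: redundant[3]}
recovered = decode_chunks(K, N, survivors)   # should be {1: [10, 20, 30]}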

The Code

The code from the previous post looks slow. There are about 2*K*K multiplications/divisions and 3*K additions/subtractions for each output byte. Well, the code was written that way to match the mathematical explanation of Lagrange Interpolation and to plot the curves. Let's make that faster.

Notice that we get to pick the x's for the polynomials. The x's are 1 to K for the input chunks and K+1 to N for the redundant chunks. We loop through L slices of the chunks (one byte from each chunk) using the same x's. We extract this loop-invariant part into lag_interp_vec(xs, x), which returns a vector of weights for the y's.

To compute each redundant chunk, we compute this vector of weights once. Then we take the y values (one from each of the K chunks at the same offset), multiply each y by the corresponding weight in the vector, and sum the results. This is effectively a dot product of the weight vector and a slice of the K chunks. We only need K multiplications and K additions for each output byte. We also skip a whole bunch of Python overhead.
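
Concretely, the per-byte work looks something like this sketch (it uses lag_interp_vec and the gf256 helpers from the listing below; the byte values are arbitrary):

# Weight vector for the redundant chunk at x = 4, interpolated
# from the input chunks at xs = [1, 2, 3] (K = 3).
xs = [1, 2, 3]
vec = lag_interp_vec(xs, 4)

# One output byte: dot product of the weights with one byte taken
# from each input chunk at the same offset.
ys = [17, 99, 200]                            # example input bytes
out = 0
for w, y in zip(vec, ys):
    out = gf256_add(out, gf256_mul(w, y))     # K multiplications, K additions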

"Decoding" the chunks is similar to encoding the chunks. As long as we have K of the N chunks, we can use the same math to compute the data for any missing key between 1 to K. 

Here's the code, including a test function test_block() to check that it actually works.


import random
import time

# GF(2**8)
gf256_log = [0]*256 # zeroth entry is invalid
gf256_antilog = [0]*255 # yes, only 255 entries

def gf256_add(x, y):
    return x^y

gf256_sub = gf256_add

def gf256_smul(x, y): # long multiplication
    r = 0
    while y>0:
        if y&1:
            r ^= x # add
        if x&0x80:
            x = (x<<1)^0x11b # reduce by polynomial
        else:
            x <<= 1
        y >>= 1
    return r

def gf256_init_tables():
    z = 1
    for x in range(255):
        gf256_antilog[x] = z
        gf256_log[z] = x
        z =  gf256_smul(z, 3)

def gf256_div(x, y):
    assert y != 0
    if x == 0:
        return 0
    # log[x] - log[y] can be negative; Python's negative indexing into the
    # 255-entry antilog table wraps around, which is exactly "mod 255".
    return gf256_antilog[(gf256_log[x] - gf256_log[y])]

def gf256_mul(x, y):
    if x == 0 or y == 0:
        return 0
    return gf256_antilog[(gf256_log[x] + gf256_log[y]) % 255]

gf256_init_tables()

def lag_interp_vec(xs, x):
    n = len(xs)
    v = n * [0]
    for i in range(n):
        t = 1
        for j in range(n):
            if i != j:
                t = gf256_mul(t,
                              gf256_div(gf256_sub(x, xs[j]),
                                        gf256_sub(xs[i], xs[j])))
        v[i] = t
    return v

def compute_chunk(K, x, xs, chunks):
    assert x not in xs
    assert K == len(xs)
    vec = lag_interp_vec(xs, x)
    L = len(chunks[xs[0]])
    result = L * [0]
    for z in range(L):
        y = 0
        for i in range(K):
            y = gf256_add(y, gf256_mul(vec[i], chunks[xs[i]][z]))
        result[z] = y
    return result

def encode_chunks(K, N, chunks):
    assert K > 0 and K < N and N < 256
    L = len(chunks[1])
    for z in range(1, K+1):
        assert len(chunks[z]) == L
    start = time.time()
    xs = list(range(1, K+1))
    encoded = {}
    for x in range(K+1, N+1):
        encoded[x] = compute_chunk(K, x, xs, chunks)
    dur = time.time()-start
    rsize = (N-K)*L
    rsize_rate = rsize/dur/1024
    dsize = K*L
    dsize_rate = dsize/dur/1024
    print(f"Encode {dur:.5f}s  data {dsize} rate {dsize_rate:.0f} kb/s  redundancy {rsize} rate {rsize_rate:.0f} kb/s")
    return encoded

def decode_chunks(K, N, chunks):
    start = time.time()
    assert K > 0 and K < N and N < 256
    xs = []
    for x in range(1, N+1):
        if x in chunks:
            xs.append(x)
        if len(xs) >= K:
            break
    assert len(xs) == K
    r = {}
    for x in range(1, K+1):
        if x not in chunks:
            r[x] = compute_chunk(K, x, xs, chunks)
    print(f"Decode time {time.time()-start:.5f}  recovered {len(r)}")
    return r

def recover_original_chunks(K, N, chunks):
    temp = decode_chunks(K, N, chunks)
    temp.update(chunks)
    for x in [j for j in temp.keys() if j > K]:
        del temp[x]
    return temp

def test_block(K, N, chunks, rounds, name):
    rr = random.Random(2)
    enc = encode_chunks(K, N, chunks)
    enc.update(chunks)
    for round in range(rounds):
        valid = list(range(1, N+1))
        while len(valid) > K:
            valid.pop(rr.randint(0, len(valid)-1))
        temp = {}
        for x in valid:
            temp[x] = enc[x]
        r = recover_original_chunks(K, N, temp)
        assert r == chunks, "Recovered chunks do not match the original chunks"
    print(f"{name} succeeded")
                
def gen_data(K, L):
    rr = random.Random(1)
    chunks = { x: list(rr.randbytes(L)) for x in range(1, K+1) }
    return chunks

test_block(16, 24, gen_data(16, 65536), 1, "test")

That's all. Just over 100 lines of Python code.

With K=16 and N=24, encoding 1 MB (16 chunks of 64 KB) takes 1.8 s (about 550 KB/s) on my M2 MacBook Air.

I have not found an erasure code algorithm that does less work than one dot product per output byte. Certainly, I haven't found one that is so easy to code!
