November 09, 2025

Systematic Reed Solomon erasure code in Python

In part 1, I visually demonstrated a version of the Reed Solomon erasure code.

Let's turn that into a little library to encode and decode chunks of data. It should be something that can be used in a large-scale cloud data storage system with millions of storage devices.

API

encode_chunks(K, N, chunks)

K is the number of input chunks. 

N is the total number of chunks wanted.

chunks is a dictionary containing the data to encode. The keys must be exactly 0 to K-1. The values should be lists of numbers, each between 0 and 255. All the chunks must be of the same length L.

The function returns a dictionary of the N-K redundant chunks in the same format as chunks. The keys are K to N-1.

decode_chunks(K, N, chunks)

K is the number of chunks.

N is the total number of chunks.

chunks is a dictionary of at least K out of the N chunks. The keys must be between 0 and N-1.

The function returns a dictionary containing all the missing chunks with keys between 0 and K-1.
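Before the real GF(256) implementation below, here is a toy round trip over the rationals that mirrors the same API contract. This stand-in is not the post's code: toy_encode_chunks and toy_decode_chunks are hypothetical names, and exact Fraction arithmetic replaces GF(256), so it only illustrates the shapes of the inputs and outputs.

```python
from fractions import Fraction

def lag_weights(xs, x):
    # Lagrange basis values l_i(x): the weight of y_i when evaluating at x.
    w = []
    for i, xi in enumerate(xs):
        t = Fraction(1)
        for j, xj in enumerate(xs):
            if i != j:
                t *= Fraction(x - xj, xi - xj)
        w.append(t)
    return w

def toy_encode_chunks(K, N, chunks):
    # Same contract as encode_chunks: returns the N-K redundant chunks, keys K..N-1.
    xs = list(range(K))
    L = len(chunks[0])
    out = {}
    for t in range(K, N):
        w = lag_weights(xs, t)  # one weight vector per redundant chunk
        out[t] = [sum(w[i] * chunks[i][z] for i in range(K)) for z in range(L)]
    return out

def toy_decode_chunks(K, N, chunks):
    # Same contract as decode_chunks: rebuild the missing chunks with keys < K.
    xs = sorted(chunks)[:K]  # any K surviving chunk indices work as nodes
    L = len(chunks[xs[0]])
    out = {}
    for t in (x for x in range(K) if x not in chunks):
        w = lag_weights(xs, t)
        out[t] = [sum(w[i] * chunks[xs[i]][z] for i in range(K)) for z in range(L)]
    return out

# Round trip: K=2 data chunks, N=4 total; lose chunk 0 and recover it.
data = {0: [1, 2], 1: [3, 4]}
enc = toy_encode_chunks(2, 4, data)
survivors = {1: data[1], 2: enc[2]}
assert toy_decode_chunks(2, 4, survivors) == {0: [1, 2]}
```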

The Code

The code from the previous post looks slow. There are about 2*K*K "multiplications/divisions" and 3*K "additions/subtractions" for each output byte in the redundant chunks. Well, the code was written that way to match the mathematical equation of Lagrange Interpolation and to plot the curves. Let's make that faster.

Notice that we get to pick the x's for the polynomials. The x's are the keys for the chunks and result dictionaries. It is 0 to K-1 for the input chunks and K to N-1 for the redundant chunks. We loop through L slices of the chunks (one byte from each chunk) using the same x's. We extract this loop-invariant part into lag_interp_weights(xs, x). This returns a vector of weights for the y's. These weights are actually the coefficients of the Lagrange polynomials.

For each redundant chunk, we compute this vector of weights once. Then we take a slice (one byte from each of the K chunks at the same offset) as the y values, multiply each y by the corresponding weight in the vector, and sum them together. In other words, the output byte for the slice is the dot product of the weight vector and a slice of the K chunks. Do that for all L slices to compute the redundant chunk. We only need K "multiplications" and K "additions" for each output byte.
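A tiny worked example of the weight-reuse idea, over plain integers rather than GF(256) (the weights below are hand-computed for illustration, not output of the post's code):

```python
# K=2 input "chunks" of L=3 bytes each.
chunks = {0: [1, 2, 3], 1: [10, 20, 30]}

# Lagrange weights for evaluating at x=2 with nodes xs=[0, 1]:
#   l0(2) = (2-1)/(0-1) = -1,   l1(2) = (2-0)/(1-0) = 2
w = [-1, 2]

# The redundant chunk at x=2: one dot product per output byte,
# reusing the same weight vector for all L slices.
redundant = [w[0] * chunks[0][z] + w[1] * chunks[1][z] for z in range(3)]
assert redundant == [19, 38, 57]
```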

Decoding the chunks is similar to encoding the chunks. As long as we have K of the N chunks, we use the keys of those chunks as the x's and the same math to compute the data for any missing key between 0 and K-1.
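The reason any K surviving chunks suffice: K points determine a degree-(K-1) polynomial uniquely. A quick sanity check over the rationals (a stand-in for GF(256), not the post's code):

```python
from fractions import Fraction
from itertools import combinations

def lag_eval(pts, x):
    # Evaluate the interpolating polynomial through pts at x (Lagrange form).
    total = Fraction(0)
    for i, (xi, yi) in enumerate(pts):
        t = Fraction(yi)
        for j, (xj, _) in enumerate(pts):
            if i != j:
                t *= Fraction(x - xj, xi - xj)
        total += t
    return total

# A degree-1 "codeword" (K=2): y = 3x + 5, sampled at x = 0..3 (N=4).
points = [(x, 3 * x + 5) for x in range(4)]

# Every K-subset of the N samples recovers the same original value p(0) = 5.
for subset in combinations(points, 2):
    assert lag_eval(list(subset), 0) == 5
```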

Here's the code, including a test function test_random() to check that it actually works.


import random
import time

# GF(2**8)
gf256_log = [0]*256 # zeroth entry is invalid
gf256_antilog = [0]*255 # yes, only 255 entries

def gf256_add(x, y):
    return x^y

gf256_sub = gf256_add

def gf256_slow_mul(x, y): # long multiplication
    r = 0
    while y>0:
        if y&1:
            r ^= x # add
        if x&0x80:
            x = (x<<1)^0x11b # reduce by polynomial
        else:
            x <<= 1
        y >>= 1
    return r

def gf256_init_tables():
    z = 1
    for x in range(255):
        gf256_antilog[x] = z
        gf256_log[z] = x
        z =  gf256_slow_mul(z, 3)

def gf256_div(x, y):
    assert y != 0
    if x == 0:
        return 0
    # the log difference can be negative; Python's negative indexing
    # into the 255-entry table is equivalent to taking it mod 255
    return gf256_antilog[gf256_log[x] - gf256_log[y]]

def gf256_mul(x, y):
    if x == 0 or y == 0:
        return 0
    return gf256_antilog[(gf256_log[x] + gf256_log[y]) % 255]

gf256_init_tables()

def lag_interp_weights(xs, x):
    n = len(xs)
    w = n * [0]
    for i in range(n):
        t = 1
        for j in range(n):
            if i != j:
                t = gf256_mul(t,
                              gf256_div(gf256_sub(x, xs[j]),
                                        gf256_sub(xs[i], xs[j])))
        w[i] = t
    return w

def compute_chunks(K, xs, ts, chunks):
    assert K == len(xs)
    results = {}
    L = len(chunks[xs[0]])
    inputs = [chunks[xs[i]] for i in range(K)]
    for t in ts:
        assert t not in xs
        w = lag_interp_weights(xs, t)
        r = L * [0]
        for z in range(L):
            y = 0
            for i in range(K):
                y = gf256_add(y, gf256_mul(w[i], inputs[i][z]))
            r[z] = y
        results[t] = r
    return results

def encode_chunks(K, N, chunks):
    assert K > 0 and K < N and N <= 256
    L = len(chunks[0])
    for z in range(K):
        assert len(chunks[z]) == L
    xs = list(range(K))
    return compute_chunks(K, xs, list(range(K, N)), chunks)

def decode_chunks(K, N, chunks):
    assert K > 0 and K < N and N <= 256
    xs = []
    for x in range(N):
        if x in chunks:
            xs.append(x)
        if len(xs) >= K:
            break
    assert len(xs) == K
    ts = [x for x in range(K) if x not in chunks]
    return compute_chunks(K, xs, ts, chunks)

def recover_original_chunks(K, N, chunks):
    temp = decode_chunks(K, N, chunks)
    temp.update(chunks)
    return {x: temp[x] for x in range(K)}  # keep only the original data chunks

def test_random(K, N, L, rounds):
    chunks = gen_data(K, L)
    rr = random.Random(2)
    estart = time.time()
    enc = encode_chunks(K, N, chunks)
    encode_time = time.time() - estart
    print(f"Encode time: {encode_time:f}s  input rate: {K*L/1024/1024/encode_time:.3f}MB/s")
    enc.update(chunks)
    dstart = time.time()
    for _ in range(rounds):
        valid = list(range(N))
        while len(valid) > K:
            valid.pop(rr.randint(0, len(valid)-1))
        temp = {}
        for x in valid:
            temp[x] = enc[x]
        r = recover_original_chunks(K, N, temp)
        assert r == chunks, f"Recovered chunks do not equal the original chunks: {valid}"
    decode_time = time.time() - dstart
    print(f"Decode time: {decode_time:f}s {rounds} rounds {decode_time/rounds:f}s/round  input rate: {K*L*rounds/1024/1024/decode_time:.3f}MB/s")
    print(f"random_{K}_{N}_{L} succeeded")

def gen_data(K, L):
    rr = random.Random(1)
    chunks = { x: list(rr.randbytes(L)) for x in range(K) }
    return chunks

test_random(16, 24, 65536, 1)

That's all. Just over 100 lines of Python code, not counting the test functions.

No solving of systems of linear equations. No Vandermonde matrix. No Gaussian elimination over Galois fields to invert matrices. No dependencies. Just plain Python.

With K=16, N=24, encoding 1 MB (16 chunks of 64 KB) takes 1.7s (600 KB/s) on my M2 MacBook Air.

With K=17, N=20, encoding 17 MB (17 chunks of 1 MB) takes 11.2s (1.5 MB/s).

I have not seen an erasure code algorithm that does less work than one dot product per output byte.
Certainly, I haven't found one that is so easy to code!


Update 2:

See the fast C version and the fast Go version.


Update:

I added test_all() to test that the data can be recovered from all possible N-choose-K subsets of chunks.


def start_comb(K, v): # fill v with the first combination [0, 1, ..., K-1]
    for i in range(K):
        v[i] = i

def next_comb(K, N, v): # advance v to the next K-combination; returns True when exhausted
    x = K-1
    while x >= 0:
        if v[x] < N-1:
            v[x] += 1
            for j in range(x+1, K):
                v[j] = v[j-1] + 1
            if v[K-1] < N:
                return False
        x -= 1
    return True

def test_all(K, N, L):
    chunks = gen_data(K, L)
    estart = time.time()
    enc = encode_chunks(K, N, chunks)
    encode_time = time.time() - estart
    print(f"Encode time: {encode_time:f}s  input rate: {K*L/1024/1024/encode_time:.3f}MB/s")
    enc.update(chunks)
    dstart = time.time()
    valid = K * [0]
    count = 0
    done = False
    start_comb(K, valid)
    while not done:
        temp = {}
        count += 1
        for x in valid:
            temp[x] = enc[x]
        r = recover_original_chunks(K, N, temp)
        assert r == chunks, f"Recovered chunks do not equal the original chunks: {valid}"
        done = next_comb(K, N, valid)
    decode_time = time.time() - dstart
    print(f"Decode time: {decode_time:f}s {count} combinations {decode_time/count:f}s/combination  input rate: {K*L*count/1024/1024/decode_time:.3f}MB/s")
    print(f"Encode time plus decode time: {encode_time+decode_time:f}s")
    print(f"all_{K}_{N}_{L} {count} combinations succeeded")

test_all(17, 20, 1024) # 1140 combinations. 17KB input. 10 seconds.
test_all(10, 16, 1024) # 8008 combinations. 10KB input. 1 minute.
test_all(16, 24, 16) # 735471 combinations. 256B input. 9.5 minutes.
test_all(16, 24, 1024) # 735471 combinations. 16KB input. 5 hours.
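As a cross-check, the hand-rolled start_comb/next_comb pair enumerates the same K-subsets that the standard library's itertools.combinations yields; a sketch (not the post's code) of how test_all's subset loop could lean on the standard library instead:

```python
from itertools import combinations
from math import comb

K, N = 3, 5
subsets = list(combinations(range(N), K))

# combinations() yields each K-subset exactly once, in sorted order,
# matching what start_comb/next_comb walks through.
assert len(subsets) == comb(N, K) == 10
assert subsets[0] == (0, 1, 2) and subsets[-1] == (2, 3, 4)
```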


2 comments:

  1. Benchmarks against existing implementations?

  2. This python code is not meant to be fast. It is meant to match the math in the first post. See next post for a fast C version.
