I've written a implementation of a merge sort in Python. It seems correct and I have tested it with this:
l = list(range(1000))
random.shuffle(l) # random.shuffle is in-place
mergesort(l) == list(range(1000)) # returns True
# Stable sort
cards = list(itertools.product(range(4), range(13)))
random.shuffle(cards)
cards = mergesort(cards, key=lambda x: x[1]) # Sort by rank
cards = mergesort(cards, key=lambda x: x[0]) # Sort by suit
cards == list(itertools.product(range(4), range(13))) # returns True
I've also tested its performance, comparing it to sorted
and the merge sort implementation in rosetta code:
rl = list(range(100))
random.shuffle(rl)
%timeit sorted(rl)
# 100000 loops, best of 3: 11.3 µs per loop
%timeit mergesort(rl) # My code
# 1000 loops, best of 3: 376 µs per loop
%timeit rcmerge_sort(rl) # From rosetta code
# 1000 loops, best of 3: 350 µs per loop
I'm looking for any suggestions on how to improve this code. I suspect there is a better way to do the mergelist
function, particularly in how I tried to avoid code duplication like:
if top_a <= top_b:
nl.append(top_a)
try:
top_a = next(it_a)
except:
...
else:
# duplicates above code
In my code I placed the iterators and first values in a list, then use the variable k
as index, but this leads to hacks like abs(k-1)
and placing magic numbers 0
and 1
in the code.
def mergesort(l, key=None):
# Split the list into sublists of length 1
sublists = [[x] for x in l]
while len(sublists) > 1:
new_sublists = []
# Create a generator that yields two sublists at a time
sublists_pairs = ((sublists[2*x], sublists[2*x+1])
for x in range(len(sublists)//2))
for a, b in sublists_pairs:
new_sublists.append(mergelist(a, b, key))
# If the length is odd, then there is one sublist that is not merged
if len(sublists) % 2 != 0:
new_sublists.append(sublists[-1])
sublists = new_sublists
return new_sublists[0]
def mergelist(a, b, key=None):
nl = []
# Iterators that yield values from a and b
its = iter(a), iter(b)
# The top of both lists
tops = [next(it) for it in its]
while True:
# Determine the iterator that the next element should be taken from
if key:
k = 0 if key(tops[0]) <= key(tops[1]) else 1
else:
k = 0 if tops[0] <= tops[1] else 1
nl.append(tops[k])
try:
# Update the top of the iterator
tops[k] = next(its[k])
except StopIteration:
# Unless the iterator is empty, in which case get the rest of
# the values from the other iterator
# abs(k-1) is similar to (0 if k == 1 else 1)
nl.append(tops[abs(k-1)])
for e in its[abs(k-1)]:
nl.append(e)
return nl