CyclicBarrier example: a parallel sort algorithm (ctd)
Having seen the general CyclicBarrier pattern for performing a task such as a parallel sort, we'll now fill in the details over the next couple of pages and actually implement the sort.
To implement the sort, I'd first suggest creating a "sorter" object that outside callers create and call into. The sorter object holds the current state of the sort, including the intermediate "buckets", and parameters such as the number of threads. If some of the other variables aren't clear, they hopefully will be by the end of this explanation.
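// Imports for the complete class, including the methods shown below
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.RandomAccess;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReadWriteLock;

public class ParallelSorter<E extends Comparable<E>> {
    // One worker thread per available processor
    private final int noThreads =
        Runtime.getRuntime().availableProcessors();
    private final int noSamplesPerThread = 16;
    // Shared seed source for each worker's Random
    private final AtomicLong randSeed =
        new AtomicLong(System.nanoTime());
    // Which stage sortStageComplete() should run next
    private volatile int stageNo = 0;
    private final int dataSize;
    private final List<E> data;
    private final List<E> splitPoints =
        new ArrayList<E>(noSamplesPerThread * noThreads);
    // One bucket per worker thread, indexed by thread number
    private final List<List<E>> bucketsToSort;
    private final ReadWriteLock dataLock;
    // noThreads workers plus the calling thread; the action runs each time all of them arrive
    private final CyclicBarrier barrier =
        new CyclicBarrier(noThreads + 1, new Runnable() {
            public void run() {
                sortStageComplete();
            }
        });

    public ParallelSorter(List<E> data, ReadWriteLock dataLock) {
        if (!(data instanceof RandomAccess))
            throw new IllegalArgumentException("List must be random access");
        this.data = data;
        this.dataLock = dataLock;
        this.dataSize = data.size();
        List<List<E>> tempList = new ArrayList<List<E>>(noThreads);
        for (int i = 0; i < noThreads; i++) {
            tempList.add(new ArrayList<E>(dataSize / noThreads));
        }
        bucketsToSort = Collections.unmodifiableList(tempList);
    }

    // ... SorterThread, sortStageComplete() and the other methods go here
}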
Safe data access from multiple threads
An important element to note is the ReadWriteLock passed to the constructor, which we'll use to guard the data list. Unless we want to take a copy of the list (bearing in mind that the whole point of a parallel sort is that the data list could be quite large), we need some way of allowing multiple threads, including the calling thread, to access it safely. Using a ReadWriteLock (as opposed to just any old Lock) gives us an advantage because there are several places where we need concurrent reads and don't expect concurrent writes. (A regular synchronized block, for example, would force threads to read one at a time, and in this case that serialisation of reads would be unnecessary.)
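To see why the read lock helps, here is a minimal standalone sketch. ConcurrentReadDemo and readersOverlap() are purely illustrative names, and ReentrantReadWriteLock is simply the standard JDK implementation of ReadWriteLock. Two threads each take the read lock and then wait for each other while still holding it. With an exclusive lock (or a synchronized block), the second reader could never get in and the latch would time out. With a read lock, both readers are inside at once.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;

public class ConcurrentReadDemo {
    // Returns true if both reader threads held the read lock at the same time.
    public static boolean readersOverlap() throws InterruptedException {
        ReadWriteLock rwLock = new ReentrantReadWriteLock();
        CountDownLatch bothInside = new CountDownLatch(2);
        boolean[] sawOther = new boolean[2];
        Thread[] readers = new Thread[2];
        for (int i = 0; i < readers.length; i++) {
            final int idx = i;
            readers[i] = new Thread(() -> {
                Lock l = rwLock.readLock();
                l.lock();
                try {
                    bothInside.countDown();
                    // Succeeds only if the other reader also gets in while we hold the lock
                    sawOther[idx] = bothInside.await(2, TimeUnit.SECONDS);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                } finally {
                    l.unlock();
                }
            });
            readers[i].start();
        }
        for (Thread t : readers) {
            t.join(); // join() also makes sawOther[] visible to this thread
        }
        return sawOther[0] && sawOther[1];
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(readersOverlap()); // prints true
    }
}
```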
The bucketsToSort field is a slightly nasty list of lists: we essentially want one list per thread, indexed by thread number. (Having a list of lists, rather than an array of lists, makes the syntax easier when dealing with generics.) Once we've constructed a given thread's list, the list object itself won't change, but its contents will. In other words, the outer List is only ever read once it has been constructed, and storing the reference to it in a final field is enough to give thread-safe access. (The Collections.unmodifiableList() wrapper is just there to ensure we don't accidentally try to modify this list; it isn't really what provides the thread-safety.)
A given thread's individual list (i.e. one of the elements of bucketsToSort) will be both read and written, so all accesses to one of these "inner" lists require synchronization on that individual list.
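The generics point is that Java rejects `new List<E>[n]` as generic array creation, which is why a list of lists is used instead. The rest of the pattern can be shown in isolation with a small sketch. BucketListDemo and its methods are illustrative names, and Integer buckets stand in for List<E>. The outer list is built once, wrapped read-only and held in a final field. Every access to an inner list synchronizes on that inner list.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

public class BucketListDemo {
    // Outer list: built once in the constructor, wrapped read-only and held in a final field
    private final List<List<Integer>> buckets;

    public BucketListDemo(int noBuckets) {
        List<List<Integer>> temp = new ArrayList<>(noBuckets);
        for (int i = 0; i < noBuckets; i++) {
            temp.add(new ArrayList<>());
        }
        buckets = Collections.unmodifiableList(temp);
    }

    // Inner lists are mutable and shared, so every access locks the inner list itself
    public void addToBucket(int bucketNo, int value) {
        List<Integer> b = buckets.get(bucketNo);
        synchronized (b) {
            b.add(value);
        }
    }

    public int bucketSize(int bucketNo) {
        List<Integer> b = buckets.get(bucketNo);
        synchronized (b) {
            return b.size();
        }
    }

    // The unmodifiable wrapper rejects structural changes to the outer list
    public boolean outerListIsReadOnly() {
        try {
            buckets.add(new ArrayList<>());
            return false;
        } catch (UnsupportedOperationException e) {
            return true;
        }
    }

    public static void main(String[] args) throws InterruptedException {
        BucketListDemo demo = new BucketListDemo(2);
        Runnable filler = () -> {
            for (int i = 0; i < 1000; i++) {
                demo.addToBucket(0, i);
            }
        };
        Thread t1 = new Thread(filler);
        Thread t2 = new Thread(filler);
        t1.start();
        t2.start();
        t1.join();
        t2.join();
        System.out.println(demo.bucketSize(0));         // prints 2000
        System.out.println(demo.outerListIsReadOnly()); // prints true
    }
}
```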
Sort worker thread
The ParallelSorter will also include the sortStageComplete() method mentioned previously, plus the inner worker class; here, we simply fill in a few details:
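private class SorterThread extends Thread {
    private final int threadNo;
    // Set if this worker fails; read later by the calling thread
    private volatile Throwable error;

    SorterThread(int no) {
        this.threadNo = no;
    }

    public void run() {
        try {
            // This thread's share of the data is [startPos, endPos). Integer arithmetic
            // guarantees the ranges are contiguous and the last one ends exactly at dataSize
            // (floating-point division can round the final endPos down by one).
            int startPos = (int) ((long) dataSize * threadNo / noThreads),
                endPos = (int) ((long) dataSize * (threadNo + 1) / noThreads);
            gatherSplitPointSample(data, startPos, endPos);
            barrier.await();
            assignItemsToBuckets(data, threadNo, startPos, endPos);
            barrier.await();
            sortMyBucket();
            barrier.await();
        } catch (InterruptedException e) {
            // Interrupted: abandon the sort
        } catch (BrokenBarrierException e) {
            // Another thread (or the stage-completion action) failed: abandon the sort
        } catch (Throwable t) {
            // Record the error, then break the barrier: calling await() with the
            // interrupt flag set throws immediately and releases the other threads
            this.error = t;
            Thread.currentThread().interrupt();
            try {
                barrier.await();
            } catch (Exception e) {
                // Expected: the barrier is now broken
            }
        }
    }

    // Other threads also access this bucket, so sort it while holding its lock
    private void sortMyBucket() {
        List<E> l = bucketsToSort.get(threadNo);
        synchronized (l) {
            Collections.sort(l);
        }
    }
}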
private void sortStageComplete() {
    try {
        switch (stageNo) {
            case 0 : amalgamateSplitPointData(); break;
            case 1 : clearData(); break;
            case 2 : combineBuckets(); break;
            default :
                throw new RuntimeException("Don't expect to be "
                        + "called at stage " + stageNo);
        }
        // Advance to the next stage; this runs once per barrier trip
        stageNo++;
    } catch (RuntimeException rte) {
        // Save the error so that the calling thread can report it, then rethrow
        completionStageError = rte;
        throw rte;
    }
}
private volatile RuntimeException completionStageError;
Notice the try/catch wrapper that saves any RuntimeException thrown during the sortStageComplete() method, as discussed in the section on error handling with CyclicBarrier.
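The effect of this wrapper can be seen in a minimal standalone sketch. BarrierActionErrorDemo is an illustrative name, and the hypothetical failingStage() stands in for a stage method such as amalgamateSplitPointData(). Whichever thread arrives last at the barrier runs the action, so the exception surfaces in that thread. According to the CyclicBarrier documentation, the barrier is then placed in the broken state and the other party gets a BrokenBarrierException. The saved field lets the coordinating thread find out what actually went wrong.

```java
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;

public class BarrierActionErrorDemo {
    // Saved by the barrier action so the coordinating thread can see what went wrong
    private volatile RuntimeException completionStageError;

    // Two parties: one worker thread plus the calling thread
    private final CyclicBarrier barrier = new CyclicBarrier(2, this::stageComplete);

    // Hypothetical stage that always fails, standing in for e.g. amalgamateSplitPointData()
    private void failingStage() {
        throw new IllegalStateException("stage failed");
    }

    // Same wrapper pattern as sortStageComplete(): save the error, then rethrow it
    private void stageComplete() {
        try {
            failingStage();
        } catch (RuntimeException rte) {
            completionStageError = rte;
            throw rte; // an exception from the barrier action breaks the barrier
        }
    }

    public RuntimeException getCompletionStageError() {
        return completionStageError;
    }

    // Runs one barrier stage; returns true if the error was saved and the barrier is broken
    public boolean runStage() throws InterruptedException {
        Thread worker = new Thread(() -> {
            try {
                barrier.await();
            } catch (InterruptedException | BrokenBarrierException | RuntimeException e) {
                // The last thread to arrive runs the action and sees IllegalStateException;
                // the other party sees BrokenBarrierException
            }
        });
        worker.start();
        try {
            barrier.await();
        } catch (BrokenBarrierException | RuntimeException e) {
            // Expected: this stage always fails
        }
        worker.join();
        return completionStageError != null && barrier.isBroken();
    }

    public static void main(String[] args) throws InterruptedException {
        BarrierActionErrorDemo demo = new BarrierActionErrorDemo();
        System.out.println(demo.runStage());                             // prints true
        System.out.println(demo.getCompletionStageError().getMessage()); // prints stage failed
    }
}
```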
Gathering split points
The first method we need is gatherSplitPointSample(), called from the SorterThread's run() method. Recall that this must select noSamplesPerThread items at random from the given portion of the data. For this, we'll simply generate noSamplesPerThread random numbers within our allocated index range, and accept that we could generate the same index twice. Normally, this possibility of duplicate indexes wouldn't be acceptable for random sample selection: to select a given number of elements at random, we'd use a proper random sampling technique. But correct random sampling has the disadvantage that it requires us to generate one random number per item in the list. If the number of samples is very small compared to the list size, the chance of a duplicate index is negligible, and for our purpose here the occasional duplicate would not matter anyway.
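For comparison, here is a sketch of one standard duplicate-free method, selection sampling (Knuth's Algorithm S), which draws one random number per item scanned. This is not necessarily the exact technique the author has in mind. Alongside it is a helper that computes the chance of a duplicate when indices are drawn independently, as gatherSplitPointSample() does. SamplingDemo and its method names are illustrative. For 16 samples from a 125,000-item range (for example, a million items split across 8 threads), the duplicate probability is below 0.1%.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

public class SamplingDemo {
    // Selection sampling (Knuth's Algorithm S): scans the list once, drawing one random
    // number per item, and picks exactly k distinct items (if k <= items.size()) in list order
    public static <T> List<T> selectionSample(List<T> items, int k, Random rand) {
        List<T> result = new ArrayList<>(k);
        int n = items.size();
        int needed = k;
        for (int i = 0; i < n && needed > 0; i++) {
            int remaining = n - i;
            // Choose item i with probability needed / remaining
            if (rand.nextInt(remaining) < needed) {
                result.add(items.get(i));
                needed--;
            }
        }
        return result;
    }

    // Probability that k indices drawn independently and uniformly from a range of
    // size m (the approach gatherSplitPointSample() takes) contain at least one duplicate
    public static double duplicateProbability(int k, int m) {
        double allDistinct = 1.0;
        for (int i = 1; i < k; i++) {
            allDistinct *= 1.0 - (double) i / m;
        }
        return 1.0 - allDistinct;
    }

    public static void main(String[] args) {
        List<Integer> items = new ArrayList<>();
        for (int i = 0; i < 100; i++) {
            items.add(i);
        }
        System.out.println(selectionSample(items, 16, new Random(42)).size()); // prints 16
        // e.g. 16 samples from one thread's share of a million items split 8 ways
        System.out.println(duplicateProbability(16, 125_000)); // about 9.6E-4, i.e. under 0.1%
    }
}
```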
To avoid too much contention on the shared splitPoints list, each thread initially adds its values to a local list and then adds that local list to the shared one at the end. (Arguably, amalgamating the per-thread lists of split points should be carried out in the amalgamateSplitPointData() method; we do it here simply because it's a bit less complicated and probably doesn't make much overall difference performance-wise.)
Each thread has its own Random object, just to avoid contention on a shared generator. We seed each one ourselves from a shared AtomicLong, incremented by 17 each time, because we know that in this case we'll create several Random instances from different threads in quick succession.
In most applications, the weaker guarantee that the Random class offers (that each instance created with the no-argument constructor is very likely to get a different seed) is good enough, and we wouldn't go to the trouble of managing our own seed generation.
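The seeding scheme can be isolated in a small sketch. SeedDemo and its methods are illustrative names. Because AtomicLong.getAndAdd() is atomic, threads that create their generators at the same moment still each receive a different seed, which is the property the sorter relies on.

```java
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;

public class SeedDemo {
    // Shared seed source, mirroring ParallelSorter's randSeed field
    private static final AtomicLong randSeed = new AtomicLong(System.nanoTime());

    // getAndAdd() is atomic, so no two callers receive the same seed
    static Random newGenerator(Set<Long> seedsUsed) {
        long seed = randSeed.getAndAdd(17);
        seedsUsed.add(seed);
        return new Random(seed);
    }

    // Starts noThreads threads that each build their own Random at about the same time;
    // returns the number of distinct seeds that were handed out
    public static int distinctSeeds(int noThreads) throws InterruptedException {
        Set<Long> seeds = ConcurrentHashMap.newKeySet();
        Thread[] threads = new Thread[noThreads];
        for (int i = 0; i < noThreads; i++) {
            threads[i] = new Thread(() -> {
                Random rand = newGenerator(seeds);
                rand.nextInt(100); // each thread then draws from its own, uncontended generator
            });
            threads[i].start();
        }
        for (Thread t : threads) {
            t.join();
        }
        return seeds.size();
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(distinctSeeds(8)); // prints 8
    }
}
```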
The gatherSplitPointSample() method therefore ends up as follows:
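private void gatherSplitPointSample(List<E> data, int startPos, int endPos) {
    // One Random per thread, each with its own seed
    Random rand = new Random(randSeed.getAndAdd(17));
    List<E> sample = new ArrayList<E>(noSamplesPerThread);
    Lock l = dataLock.readLock();
    l.lock();
    try {
        // Pick noSamplesPerThread random indices in [startPos, endPos);
        // an occasional duplicate index is harmless here
        for (int i = 0; i < noSamplesPerThread; i++) {
            int n = rand.nextInt(endPos - startPos) + startPos;
            sample.add(data.get(n));
        }
    } finally {
        l.unlock();
    }
    // Add the local sample to the shared list in one go
    synchronized (splitPoints) {
        splitPoints.addAll(sample);
    }
}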
Notice the call to dataLock.readLock() to fetch the read half of the lock, followed by the lock() and unlock() calls, with unlock() in a finally block so that the lock is always released. At this stage, we know that all accesses to the data are reads (and multiple threads holding read locks won't stall one another). So we happily hold on to the lock over the entire loop, rather than just around the call to data.get(), which is where it is strictly necessary.
Amalgamating the split point data (which, you'll recall, executes in a single thread after every thread has gathered its sample) is simply a question of sorting the sample values and then taking every 16th value, at indices 16, 32 and so on (because noSamplesPerThread is 16 in our case), giving noThreads - 1 split points. In the following implementation, we copy the combined samples into a temporary list, sort that, and then put just the required values back into the splitPoints list.
The code's pretty much what you'd expect. Remember, we must still synchronize for data visibility reasons, because other threads have been and will be accessing splitPoints:
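private void amalgamateSplitPointData() {
    synchronized (splitPoints) {
        // Sort a copy of all the samples gathered by the worker threads
        List<E> spl = new ArrayList<E>(splitPoints);
        Collections.sort(spl);
        // Keep noThreads - 1 evenly spaced values as the bucket boundaries
        splitPoints.clear();
        for (int i = 1; i < noThreads; i++) {
            splitPoints.add(spl.get(i * noSamplesPerThread));
        }
    }
}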
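To make the index arithmetic concrete, here is the same selection rule as a standalone generic method. This is a sketch, and SplitPointDemo and chooseSplitPoints are illustrative names. With 4 threads and 4 samples per thread, the 16 sorted samples 0 to 15 yield the 3 boundaries 4, 8 and 12, one fewer split point than there are buckets.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

public class SplitPointDemo {
    // Same selection rule as amalgamateSplitPointData(): sort all samples, then keep
    // the values at indices noSamplesPerThread, 2 * noSamplesPerThread, ...
    public static <E extends Comparable<E>> List<E> chooseSplitPoints(
            List<E> samples, int noThreads, int noSamplesPerThread) {
        List<E> sorted = new ArrayList<>(samples);
        Collections.sort(sorted);
        List<E> splitPoints = new ArrayList<>(noThreads - 1);
        for (int i = 1; i < noThreads; i++) {
            splitPoints.add(sorted.get(i * noSamplesPerThread));
        }
        return splitPoints;
    }

    public static void main(String[] args) {
        // 4 threads x 4 samples each: the values 0..15 in a scrambled order
        List<Integer> samples = new ArrayList<>();
        for (int i = 0; i < 16; i++) {
            samples.add(i);
        }
        Collections.shuffle(samples, new Random(1));
        System.out.println(chooseSplitPoints(samples, 4, 4)); // prints [4, 8, 12]
    }
}
```

Because the samples are sorted first, the order in which the worker threads happened to add them makes no difference to the split points chosen.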
Next: stages 2 and 3 of the sort
On the next page, we look at implementing stages 2 and 3 of the parallel sort.
Editorial page content written by Neil Coffey. Copyright © Javamex UK 2021. All rights reserved.