Playing with Multithreaded Programming in Java
Multithreaded programming can be tricky, but it is fun. Many years ago I created a small, simple Java class that implements basic load balancing across threads. Since then I haven’t found anything similar in the Java standard library, so I have written it down again here to share. The code itself is quite self-explanatory; its simplicity is its beauty. Your comments are highly appreciated.
package com.akira.lib;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Optional;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
/**
 * A queue of data elements that are handed out to worker threads for processing.
 *
 * <p>Elements are served first come, first served. Because the time needed to process each
 * element may vary, the queue balances the load simply by having each thread come back and take
 * the next element when it is ready for more work; a thread that finishes early just takes more.
 *
 * <p>A thread that finds the queue empty blocks in {@link #take()} until an element is added, or
 * until it is asked to shut down via {@link #notifyThreadShutdownStart(Set)}.
 *
 * <p>All internal state is guarded by a single {@link ReentrantLock}, so every public method may
 * be called from any thread.
 *
 * @param <T> type of the data elements for the threads to process
 */
public class LoadBalanceQueue<T> {
    /** Pending data elements, in arrival order. */
    private final Queue<T> queue = new LinkedList<>();
    /** Threads currently waiting inside {@link #take()} for an element. */
    private final Set<Thread> blockedThreads = new HashSet<>();
    /** Threads that have been asked to shut down and not yet reported as shut down. */
    private final Set<Thread> stoppedThreads = new HashSet<>();
    /** Guards all of the fields above. */
    private final Lock lock = new ReentrantLock();
    /** Signalled when elements are added or when threads are asked to shut down. */
    private final Condition objectAvailable = lock.newCondition();

    /** Creates an empty queue. */
    public LoadBalanceQueue() {
    }

    /**
     * Returns the number of data elements currently waiting in the queue.
     *
     * @return the number of queued data elements
     */
    public int size() {
        lock.lock();
        try {
            return queue.size();
        } finally {
            lock.unlock();
        }
    }

    /**
     * Takes the next data element from the queue, blocking while the queue is empty.
     *
     * <p>An empty result tells the calling thread to shut down. This happens when the queue is
     * empty and the calling thread has been registered through
     * {@link #notifyThreadShutdownStart(Set)}, whether before it started waiting or while it was
     * waiting. It also happens when the thread is interrupted while waiting; in that case the
     * thread's interrupt status is restored so the caller can still observe it.
     *
     * @return if the Optional is not empty, its value is the data element for your thread to
     *     process; otherwise your thread is asked to shut down
     */
    public Optional<T> take() {
        Thread currentThread = Thread.currentThread();
        lock.lock();
        try {
            while (queue.isEmpty()) {
                // Check before waiting as well: a shutdown request issued before this thread
                // began waiting was signalled only once, so waiting now would block forever.
                if (stoppedThreads.contains(currentThread)) {
                    return Optional.empty();
                }
                blockedThreads.add(currentThread);
                objectAvailable.await();
                if (stoppedThreads.contains(currentThread)) {
                    return Optional.empty();
                }
            }
            return Optional.of(queue.remove());
        } catch (InterruptedException e) {
            // await() cleared the interrupt flag when it threw; restore it for the caller.
            Thread.currentThread().interrupt();
            return Optional.empty();
        } finally {
            blockedThreads.remove(currentThread);
            lock.unlock();
        }
    }

    /**
     * Adds a data element to the queue and wakes one waiting thread.
     *
     * @param object a data element; must not be {@code null}, because {@link #take()} reports
     *     "no element" as a shutdown request
     * @throws NullPointerException if {@code object} is {@code null}
     */
    public void add(T object) {
        if (object == null) {
            throw new NullPointerException("object must not be null");
        }
        lock.lock();
        try {
            queue.add(object);
            objectAvailable.signal();
        } finally {
            lock.unlock();
        }
    }

    /**
     * Adds a list of data elements to the queue and wakes all waiting threads.
     *
     * <p>The list is validated before anything is added, so a list containing {@code null}
     * leaves the queue unchanged.
     *
     * @param objects a list of data elements; neither the list nor any element may be
     *     {@code null}
     * @throws NullPointerException if {@code objects} or any of its elements is {@code null}
     */
    public void addAll(List<? extends T> objects) {
        for (T object : objects) {
            if (object == null) {
                throw new NullPointerException("objects must not contain null");
            }
        }
        lock.lock();
        try {
            queue.addAll(objects);
            objectAvailable.signalAll();
        } finally {
            lock.unlock();
        }
    }

    /**
     * Returns how many threads are currently blocked in {@link #take()} waiting for data.
     *
     * @return the number of idle threads waiting for data elements
     */
    public int getNumIdleThread() {
        lock.lock();
        try {
            return blockedThreads.size();
        } finally {
            lock.unlock();
        }
    }

    /**
     * Returns how many threads have been asked to shut down and not yet reported as shut down.
     *
     * @return the number of threads waiting to shut down
     */
    public int getNumStoppedThread() {
        lock.lock();
        try {
            return stoppedThreads.size();
        } finally {
            lock.unlock();
        }
    }

    /**
     * Notifies the queue that the given threads are about to shut down, and wakes every waiting
     * thread. A registered thread whose {@link #take()} finds the queue empty gets an empty
     * result.
     *
     * @param threads threads that are about to shut down
     */
    public void notifyThreadShutdownStart(Set<? extends Thread> threads) {
        lock.lock();
        try {
            stoppedThreads.addAll(threads);
            objectAvailable.signalAll();
        } finally {
            lock.unlock();
        }
    }

    /**
     * Notifies the queue that the given threads have been shut down, removing them from the set
     * of threads waiting to shut down.
     *
     * @param threads threads that have been shut down
     */
    public void notifyThreadShutdownEnd(Set<? extends Thread> threads) {
        lock.lock();
        try {
            stoppedThreads.removeAll(threads);
        } finally {
            lock.unlock();
        }
    }
}
I also wrote a quick-and-dirty JUnit test for it.
package com.akira.lib;
import org.junit.jupiter.api.Test;
import java.util.HashSet;
import java.util.Optional;
import java.util.Set;
import static org.junit.jupiter.api.Assertions.assertEquals;
/**
 * End-to-end JUnit test for {@link LoadBalanceQueue}: several worker threads drain a queue of
 * simulated work items, go idle once it is empty, and exit when asked to shut down.
 */
class LoadBalanceQueueTest {
    /** Number of worker threads started by the test. */
    private static final int NUM_THREADS = 4;
    /** Number of work items; item {@code n} simulates {@code n} units of work. */
    private static final int NUM_ITEMS = 10;
    /** Duration of one simulated unit of work, in milliseconds. */
    private static final long MILLIS_PER_UNIT = 1000;
    /** Upper bound for each wait in the test, so a broken queue fails instead of hanging. */
    private static final long TIMEOUT_MILLIS = 60_000;
    /** Interval between checks while waiting for the workers to become idle. */
    private static final long POLL_MILLIS = 200;

    /** Worker that takes items from the queue and "processes" each one by sleeping. */
    static class RunnableTester implements Runnable {
        private final LoadBalanceQueue<Integer> queue;

        public RunnableTester(LoadBalanceQueue<Integer> queue) {
            this.queue = queue;
        }

        @Override
        public void run() {
            try {
                Optional<Integer> data;
                // An empty result from take() means this worker has been asked to shut down.
                while ((data = queue.take()).isPresent()) {
                    Thread.sleep(data.get() * MILLIS_PER_UNIT);
                }
            } catch (InterruptedException e) {
                // The worker exits either way; keep the interrupt status visible.
                Thread.currentThread().interrupt();
            }
        }
    }

    @Test
    void testLoadBalanceQueue() throws InterruptedException {
        LoadBalanceQueue<Integer> queue = new LoadBalanceQueue<>();
        for (int i = 1; i <= NUM_ITEMS; i++) {
            queue.add(i);
        }

        Set<Thread> threads = new HashSet<>();
        for (int i = 0; i < NUM_THREADS; i++) {
            Thread thread = new Thread(new RunnableTester(queue), "LoadBalanceThread: " + i);
            threads.add(thread);
            thread.start();
            // Stagger the starts so the workers pick up items at different times.
            Thread.sleep(MILLIS_PER_UNIT);
        }
        // Items are still queued, so every worker is busy and none has been asked to stop.
        assertEquals(0, queue.getNumIdleThread());
        assertEquals(0, queue.getNumStoppedThread());

        // Wait (bounded) until the queue is drained and every worker is blocked in take().
        long deadline = System.nanoTime() + TIMEOUT_MILLIS * 1_000_000L;
        while (queue.getNumIdleThread() < NUM_THREADS) {
            if (System.nanoTime() - deadline > 0) {
                throw new AssertionError(
                        "workers did not become idle within " + TIMEOUT_MILLIS + " ms");
            }
            Thread.sleep(POLL_MILLIS);
        }
        assertEquals(NUM_THREADS, queue.getNumIdleThread());
        assertEquals(0, queue.getNumStoppedThread());
        assertEquals(0, queue.size());

        queue.notifyThreadShutdownStart(threads);
        assertEquals(NUM_THREADS, queue.getNumStoppedThread());

        // Join instead of polling the idle count: leaving take() does not prove a thread exited.
        for (Thread thread : threads) {
            thread.join(TIMEOUT_MILLIS);
            assertEquals(Thread.State.TERMINATED, thread.getState(),
                    thread.getName() + " did not shut down");
        }
        assertEquals(0, queue.getNumIdleThread());

        queue.notifyThreadShutdownEnd(threads);
        assertEquals(0, queue.getNumStoppedThread());
    }
}