How AlphaDev is Changing the Game for Sorting Algorithms
In a recent paper titled "Faster sorting algorithms discovered using deep reinforcement learning," researchers from Google DeepMind introduced AlphaDev, a deep reinforcement learning (DRL) agent that automatically discovers correct, efficient sorting algorithms that outperform previously known human benchmarks. AlphaDev learned to construct sorting routines from low-level CPU instructions, optimizing for speed. The team trained the agent on a variety of input sequences and used a reward signal that encourages sorting the data correctly and as quickly as possible.
The discovered algorithms have already been open-sourced and merged into LLVM's libc++ standard sorting library. They make sorting up to 70% faster for shorter sequences and about 1.7% faster for sequences exceeding 250,000 elements.
What is AlphaDev?
AlphaDev is an extension of AlphaZero, a reinforcement learning model that defeated world champions in games like Go, chess and shogi. With AlphaDev, Google DeepMind has shown how this model can transfer from games to scientific challenges and from simulations to real-world applications. The agent comprises two core components: a learning algorithm that extends AlphaZero's DRL to play the single-player AssemblyGame, and a representation function that captures the structure of assembly programs.
Architecture
Each state in this game is defined as a vector St = ⟨Pt, Zt⟩, where Pt is the algorithm generated so far and Zt represents the state of memory and registers after executing Pt on a set of predefined inputs.
At timestep t, the player receives the current state St and executes an action at. This involves appending a legal assembly instruction (for example, mov<A,B>) to the algorithm generated so far.
A reward rt is received that comprises both a measure of algorithm correctness and latency.
The game is executed for a limited number of steps, after which it is terminated. Winning corresponds to generating a correct, low-latency algorithm using assembly instructions; losing corresponds to generating an incorrect algorithm, or a correct but inefficient one.
What is happening under the hood?
AlphaDev plays this single-player game using an extension of the AlphaZero agent, guiding a Monte Carlo tree search (MCTS) planning procedure with a deep neural network. The input to the neural network is the state St, and the output is a policy prediction and a value prediction:
Policy prediction: The policy network outputs a probability distribution over all legal moves in the current state St, assigning higher probability to more promising instructions.
Value prediction: The value network estimates the cumulative return R that the agent should expect to receive from the current state St; together with the policy priors, these estimates guide MCTS.
The policy and value networks are trained together using a loss function that combines both policy and value losses. A lower policy loss indicates a more accurate selection of the best move, while a lower value loss indicates a more accurate prediction of the final outcome.
During a game, the agent receives as input the current state St. The agent then executes an MCTS procedure and uses this to select the next action to take. The generated games are then used to update the network’s parameters, enabling the agent to learn.
Encoders
To efficiently explore the space of instructions, AlphaDev needs a representation capable of capturing complex algorithmic structure. Its representation network comprises two components: a Transformer encoder that embeds the structure of the assembly program generated so far, and a CPU state encoder, a multilayer perceptron that embeds the current state of memory and registers.
Results
The confidence intervals are represented as latency ± (lower, upper), in which latency corresponds to the fifth percentile of latency measurements across 100 different machines. Lower and upper refer to the bounds of the 95% confidence interval for this percentile.
Fixed Sorting Algorithm
The paper discusses three fundamental algorithms: sort 3, sort 4 and sort 5. The state-of-the-art human benchmarks for these algorithms are sorting networks, as they generate efficient conditional branchless assembly code. Improving on these algorithms is challenging because they are already highly optimized. AlphaDev is able to find algorithms with fewer instructions than the human benchmarks for sort 3 and sort 5, and it matches the state-of-the-art performance on sort 4. These shorter algorithms lead to lower latency, as algorithm length and latency are correlated for conditional branchless programs.
New Algorithm by AlphaDev
1. AlphaDev swap move
Fig. 6a presents an optimal sorting network for three elements. The circled part of the network (the last two comparators) can be seen as a sequence of instructions that takes an input sequence ⟨A, B, C⟩ and transforms each input as shown in Table 2a (left). However, a comparator on wires B and C precedes this operator, so input sequences with B ≤ C are guaranteed. This means it is enough to compute min(A, B) as the first output instead of min(A, B, C), as shown in Table 2a (right). The pseudocode difference between Fig. 6b,c demonstrates how the AlphaDev swap move saves one instruction each time it is applied.
2. AlphaDev copy move
Figure 7d presents a sorting network configuration, consisting of three comparators, that is applied across four wires. This configuration is found in a sort 8 sorting network and corresponds to an operator taking four inputs ⟨A, B, C, D⟩ and transforming them into four outputs, as seen in Table 2b (left). One can show that as part of sort 8, the input flowing into the operator satisfies the inequality D ≥ min(A, C). This means the operator can be improved by applying the AlphaDev copy move defined in Table 2b (right), resulting in one fewer instruction than the original operator. The code before and after applying the AlphaDev copy move is visualized in Fig. 7e,f, respectively.
Variable Sort Algorithm
The paper discusses three variable sorting algorithms: VarSort3, VarSort4 and VarSort5. The human benchmark in each case is defined as an algorithm that calls the corresponding sorting network for a given input length.
The agent needs to determine how many subalgorithms to construct while building the body of the main algorithm in parallel. It may also need to call subalgorithms from other subalgorithms. In this case, optimizing for length leads to significantly shorter algorithms than the human benchmarks.
New Algorithm by AlphaDev for VarSort4
This algorithm takes sequences of two, three or four numbers as input. If the length is two, it calls the sort 2 sorting network and returns. If the length is three, it calls sort 3 to sort the first three numbers and returns. If the length is greater than three, it calls sort 3, followed by a simplified sort 4 routine that sorts the remaining unsorted element.