Rewriting Benes network code (in progress)
parent
c2da5aa464
commit
0d416f0018
|
@ -0,0 +1,214 @@
|
|||
package meerkat.mixer.mixing;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
/**
|
||||
* Generate the Benes Network
|
||||
*/
|
||||
public class BenesNetwork implements PermutationNetwork
|
||||
{
|
||||
/**
|
||||
* log (base 2) of number of inputs to Benes network (N = 2^{logN})
|
||||
*/
|
||||
final int logN;
|
||||
|
||||
/**
|
||||
* Values of switches set for a specific permutation.
|
||||
* switchValues[layer][switchNum] is true if the corresponding switch applies the identity permutation on its inputs.
|
||||
* layer can be in the range [0, 2*logN-1), switchNum in the range (0, 2^{logN-1}-1)
|
||||
*/
|
||||
final boolean[][] switchValues;
|
||||
|
||||
public BenesNetwork(int logN) {
|
||||
this.logN = logN;
|
||||
|
||||
switchValues = new boolean[2*logN - 1][];
|
||||
for (int i = 0; i < switchValues.length; ++i)
|
||||
switchValues[i] = new boolean[1 << (logN - 1)];
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the index of the switch output in the previous layer that connects to a specified input in the current layer.
|
||||
* Each consecutive pair of outputs is from the same switch, and each consecutive pair of inputs goes to the same switch.
|
||||
* (that is, switch j has inputs/outputs 2j and 2j+1)
|
||||
*
|
||||
* @param logN log (base 2) of number of inputs to Benes network (N = 2^{logN})
|
||||
* @param layer current layer index. Must be between 1 and 2*logN-2 (layer 0 doesn't have a previous layer)
|
||||
* @param inputIdx the input Idx for the current layer (must be between 0 and (1 << logN) - 1
|
||||
*
|
||||
* @return the requested index
|
||||
*/
|
||||
public static int getOutputIdxInPreviousLayer(int logN, int layer, int inputIdx) {
|
||||
assert (layer > 0) && (layer < 2*logN - 1);
|
||||
assert (inputIdx >= 0) && (inputIdx < 1 << inputIdx);
|
||||
|
||||
if ((inputIdx & 1) == 0) {
|
||||
// Even inputs are connected straight "across" everywhere
|
||||
return inputIdx;
|
||||
}
|
||||
|
||||
// --- Odd inputs are connected depending on layer ---
|
||||
|
||||
// In middle layer everything goes across
|
||||
if (layer == logN)
|
||||
return inputIdx;
|
||||
|
||||
int crossBit;
|
||||
if (layer < logN)
|
||||
crossBit = logN - layer;
|
||||
else
|
||||
crossBit = layer - logN;
|
||||
|
||||
return inputIdx ^ (1 << crossBit);
|
||||
}
|
||||
|
||||
/**
|
||||
* Inverse of {@link #getOutputIdxInPreviousLayer(int, int, int)}
|
||||
* @param logN
|
||||
* @param layer
|
||||
* @param outputIdx
|
||||
* @return
|
||||
*/
|
||||
public static int getInputIdxInNextLayer(int logN, int layer, int outputIdx) {
|
||||
return getOutputIdxInPreviousLayer(logN, layer + 1, outputIdx);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getNumInputs() {
|
||||
return 1 << logN;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getNumLayers() {
|
||||
return 2*logN - 1;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getOutputIdxInPreviousLayer(int layer, int inputIdx) {
|
||||
return getOutputIdxInPreviousLayer(logN, layer, inputIdx);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Find in array
|
||||
* TODO: replace with more efficient data structure if this becomes a performance bottleneck
|
||||
* @param val
|
||||
* @param arr
|
||||
* @return
|
||||
*/
|
||||
static private int find(int val, int[] arr) {
|
||||
for (int i = 0; i < arr.length; ++i) {
|
||||
if (arr[i] == val)
|
||||
return i;
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
/**
|
||||
* The recursive algorithm attributed to H. Stone by A. Waksman
|
||||
* (see pg. 161 in <a href="http://grid.cs.gsu.edu/~wkim/index_files/permutation_network.pdf">paper</a>)
|
||||
* @param permOut
|
||||
* @param level
|
||||
*/
|
||||
private void setInternalPermutation(int[] permIn, int[] permOut, int idx, int level) {
|
||||
// Make sure idx is in the the range [0, 2^{logN-level-1})
|
||||
assert (idx == (idx & ((1 << (logN - level - 1)) - 1)));
|
||||
assert (permOut.length == (1 << (logN - level)));
|
||||
assert (permIn.length == permOut.length);
|
||||
|
||||
// Base case
|
||||
if (level == logN - 1) {
|
||||
if (permOut[0] == permIn[0])
|
||||
switchValues[level][idx] = false;
|
||||
else
|
||||
switchValues[level][idx] = true;
|
||||
return;
|
||||
}
|
||||
|
||||
int nextLen = permOut.length >>> 1;
|
||||
int[] upperPermIn = new int[nextLen];
|
||||
int[] lowerPermIn = new int[nextLen];
|
||||
|
||||
int[] upperPermOut = new int[nextLen];
|
||||
int[] lowerPermOut = new int[nextLen];
|
||||
|
||||
|
||||
SortedSet<Integer> unmatchedIndices = new TreeSet<>();
|
||||
for (int i = 0; i < permIn.length; ++i) {
|
||||
unmatchedIndices.add(i);
|
||||
}
|
||||
|
||||
/**
|
||||
* Where the set of switches for this level and index actually starts in switchValues
|
||||
*/
|
||||
int blockStart = idx << level;
|
||||
int numSwitches = 1 << (logN - level - 1);
|
||||
|
||||
while (!unmatchedIndices.isEmpty()) {
|
||||
int i = unmatchedIndices.first();
|
||||
|
||||
do {
|
||||
//assert ((i & 1) == 0); // Algorithm guarantees only even indices are unmatched at start of loop
|
||||
unmatchedIndices.remove(i);
|
||||
|
||||
int switchNum = i >>> 1;
|
||||
|
||||
if (((i & 1) == 0) && (switchNum < numSwitches / 2) || ((i & 1) == 1) && (switchNum >= numSwitches / 2)) {
|
||||
// i is in the upper half and even, or in the lower half and odd, so
|
||||
// switch must be "straight" to get i to upper half.
|
||||
switchValues[2 * logN - level - 1][blockStart + switchNum] = false;
|
||||
|
||||
} else {
|
||||
switchValues[2 * logN - level - 1][blockStart + switchNum] = true;
|
||||
}
|
||||
|
||||
int iConnectedTo = getOutputIdxInPreviousLayer(logN - level, 2 * (logN - level), i);
|
||||
int iPairConnectedTo = getOutputIdxInPreviousLayer(logN - level, 2 * (logN - level), i ^ 1);
|
||||
|
||||
upperPermOut[iConnectedTo] = permOut[i];
|
||||
lowerPermOut[iPairConnectedTo] = permOut[i ^ 1];
|
||||
|
||||
int j = find(permOut[i], permIn);
|
||||
int jSwitchNum = j >>> 1;
|
||||
int j
|
||||
|
||||
if ((((j & 1) == 0) && (jSwitchNum < numSwitches / 2)) || ((j & 1) == 1) && (jSwitchNum >= numSwitches / 2)) {
|
||||
// Even output in the upper half, or odd output in the lower half
|
||||
// so switch needs to be "straight" to get j to the upper half
|
||||
switchValues[level][blockStart + switchNum] = false;
|
||||
} else {
|
||||
// Otherwise switch needs to be "crossed" to get j to upper half
|
||||
switchValues[level][blockStart + switchNum] = true;
|
||||
j ^= 1;
|
||||
}
|
||||
|
||||
int jConnectedTo = getInputIdxInNextLayer(logN - level, 0, j);
|
||||
upperPermIn[jConnectedTo] = j;
|
||||
|
||||
lowerPermIn[i] = permIn[i + 1];
|
||||
|
||||
} while (unmatchedIndices.contains(i));
|
||||
}
|
||||
|
||||
setInternalPermutation(upperPermIn, upperPermOut, idx, level + 1);
|
||||
setInternalPermutation(lowerPermIn, lowerPermOut, idx, level + 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setPermutation(int[] permutation) {
|
||||
|
||||
// We use the recursive algorithm attributed to ? in ?'s paper
|
||||
int level = 0;
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> void permuteLayer(T[] values, int layer) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> void permuteLayer(ArrayList<T> values, int layer) {
|
||||
|
||||
}
|
||||
}
|
|
@ -23,7 +23,7 @@ public class MixNetwork {
|
|||
}
|
||||
|
||||
/**
|
||||
* implements benes mix network algorithm
|
||||
* implements Benes mix network algorithm
|
||||
* @param permutation - random permutation
|
||||
* @return switches
|
||||
*/
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
package meerkat.mixer.mixing;
|
||||
|
||||
import java.util.ArrayList;
|
||||
|
||||
/**
|
||||
* A generic permutation network composed of 2-input switches.
|
||||
*/
|
||||
public interface PermutationNetwork {
|
||||
/**
|
||||
* Return the number of inputs supported by the network.
|
||||
* @return
|
||||
*/
|
||||
int getNumInputs();
|
||||
|
||||
/**
|
||||
* Return the number of layers supported by the network.
|
||||
* @return
|
||||
*/
|
||||
int getNumLayers();
|
||||
|
||||
int getOutputIdxInPreviousLayer(int layer, int inputIdx);
|
||||
|
||||
/**
|
||||
* Initialize switches to generate a specific permutation.
|
||||
* @param permutation
|
||||
*/
|
||||
void setPermutation(int[] permutation);
|
||||
|
||||
|
||||
/**
|
||||
* Apply a single layer's permutation (as implied by {@link #setPermutation(int[])})
|
||||
* @param values values output by previous layer
|
||||
* @param layer layer to apply (layer is between 1 and numLayers; layer 0 is the input layer)
|
||||
* @param <T> type of values.
|
||||
*/
|
||||
<T> void permuteLayer(T[] values, int layer);
|
||||
|
||||
<T> void permuteLayer(ArrayList<T> values, int layer);
|
||||
}
|
|
@ -17,7 +17,7 @@ public class RandomPermutation {
|
|||
* @param n permutation size
|
||||
* @param random
|
||||
*/
|
||||
public RandomPermutation(int n,Random random) {
|
||||
public RandomPermutation(int n, Random random) {
|
||||
this.permutation = generatePermutation(n,random);
|
||||
}
|
||||
|
||||
|
@ -27,19 +27,20 @@ public class RandomPermutation {
|
|||
* @param random
|
||||
* @return permutation
|
||||
*/
|
||||
private int[] generatePermutation(int n,Random random){
|
||||
List<Integer> numbers= new ArrayList<Integer>(n);
|
||||
for (int i = 0; i < n; i++) {
|
||||
numbers.add(i);
|
||||
public static int[] generatePermutation(int n, Random random){
|
||||
int[] result = new int[n];
|
||||
|
||||
// initialize and permute in one pass using "inside-out" Fisher-Yates
|
||||
for (int i = 0; i < n; ++i) {
|
||||
int j = random.nextInt(i+1);
|
||||
|
||||
if (j != i) {
|
||||
result[i] = result[j];
|
||||
}
|
||||
|
||||
result[j] = i;
|
||||
}
|
||||
|
||||
int[] result = new int[n];
|
||||
int index;
|
||||
for (int i = 0; i < n; i++) {
|
||||
index = random.nextInt(n - i);
|
||||
result[i] = numbers.get(index);
|
||||
numbers.remove(index);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,15 @@
|
|||
package meerkat.mixer.mixing;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
/**
|
||||
* Tests for Benes Network topology
|
||||
*/
|
||||
public class BenesNetworkTest {
|
||||
public static class Permutation {
|
||||
|
||||
}
|
||||
|
||||
public static isAPerm
|
||||
|
||||
}
|
Loading…
Reference in New Issue