Rewriting Benes network code (in progress)
parent
c2da5aa464
commit
0d416f0018
|
@ -0,0 +1,214 @@
|
||||||
|
package meerkat.mixer.mixing;
|
||||||
|
|
||||||
|
import java.util.*;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate the Benes Network
|
||||||
|
*/
|
||||||
|
public class BenesNetwork implements PermutationNetwork
|
||||||
|
{
|
||||||
|
/**
|
||||||
|
* log (base 2) of number of inputs to Benes network (N = 2^{logN})
|
||||||
|
*/
|
||||||
|
final int logN;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Values of switches set for a specific permutation.
|
||||||
|
* switchValues[layer][switchNum] is true if the corresponding switch applies the identity permutation on its inputs.
|
||||||
|
* layer can be in the range [0, 2*logN-1), switchNum in the range (0, 2^{logN-1}-1)
|
||||||
|
*/
|
||||||
|
final boolean[][] switchValues;
|
||||||
|
|
||||||
|
public BenesNetwork(int logN) {
|
||||||
|
this.logN = logN;
|
||||||
|
|
||||||
|
switchValues = new boolean[2*logN - 1][];
|
||||||
|
for (int i = 0; i < switchValues.length; ++i)
|
||||||
|
switchValues[i] = new boolean[1 << (logN - 1)];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Return the index of the switch output in the previous layer that connects to a specified input in the current layer.
|
||||||
|
* Each consecutive pair of outputs is from the same switch, and each consecutive pair of inputs goes to the same switch.
|
||||||
|
* (that is, switch j has inputs/outputs 2j and 2j+1)
|
||||||
|
*
|
||||||
|
* @param logN log (base 2) of number of inputs to Benes network (N = 2^{logN})
|
||||||
|
* @param layer current layer index. Must be between 1 and 2*logN-2 (layer 0 doesn't have a previous layer)
|
||||||
|
* @param inputIdx the input Idx for the current layer (must be between 0 and (1 << logN) - 1
|
||||||
|
*
|
||||||
|
* @return the requested index
|
||||||
|
*/
|
||||||
|
public static int getOutputIdxInPreviousLayer(int logN, int layer, int inputIdx) {
|
||||||
|
assert (layer > 0) && (layer < 2*logN - 1);
|
||||||
|
assert (inputIdx >= 0) && (inputIdx < 1 << inputIdx);
|
||||||
|
|
||||||
|
if ((inputIdx & 1) == 0) {
|
||||||
|
// Even inputs are connected straight "across" everywhere
|
||||||
|
return inputIdx;
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Odd inputs are connected depending on layer ---
|
||||||
|
|
||||||
|
// In middle layer everything goes across
|
||||||
|
if (layer == logN)
|
||||||
|
return inputIdx;
|
||||||
|
|
||||||
|
int crossBit;
|
||||||
|
if (layer < logN)
|
||||||
|
crossBit = logN - layer;
|
||||||
|
else
|
||||||
|
crossBit = layer - logN;
|
||||||
|
|
||||||
|
return inputIdx ^ (1 << crossBit);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Inverse of {@link #getOutputIdxInPreviousLayer(int, int, int)}
|
||||||
|
* @param logN
|
||||||
|
* @param layer
|
||||||
|
* @param outputIdx
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public static int getInputIdxInNextLayer(int logN, int layer, int outputIdx) {
|
||||||
|
return getOutputIdxInPreviousLayer(logN, layer + 1, outputIdx);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getNumInputs() {
|
||||||
|
return 1 << logN;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getNumLayers() {
|
||||||
|
return 2*logN - 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getOutputIdxInPreviousLayer(int layer, int inputIdx) {
|
||||||
|
return getOutputIdxInPreviousLayer(logN, layer, inputIdx);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Find in array
|
||||||
|
* TODO: replace with more efficient data structure if this becomes a performance bottleneck
|
||||||
|
* @param val
|
||||||
|
* @param arr
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
static private int find(int val, int[] arr) {
|
||||||
|
for (int i = 0; i < arr.length; ++i) {
|
||||||
|
if (arr[i] == val)
|
||||||
|
return i;
|
||||||
|
}
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The recursive algorithm attributed to H. Stone by A. Waksman
|
||||||
|
* (see pg. 161 in <a href="http://grid.cs.gsu.edu/~wkim/index_files/permutation_network.pdf">paper</a>)
|
||||||
|
* @param permOut
|
||||||
|
* @param level
|
||||||
|
*/
|
||||||
|
private void setInternalPermutation(int[] permIn, int[] permOut, int idx, int level) {
|
||||||
|
// Make sure idx is in the the range [0, 2^{logN-level-1})
|
||||||
|
assert (idx == (idx & ((1 << (logN - level - 1)) - 1)));
|
||||||
|
assert (permOut.length == (1 << (logN - level)));
|
||||||
|
assert (permIn.length == permOut.length);
|
||||||
|
|
||||||
|
// Base case
|
||||||
|
if (level == logN - 1) {
|
||||||
|
if (permOut[0] == permIn[0])
|
||||||
|
switchValues[level][idx] = false;
|
||||||
|
else
|
||||||
|
switchValues[level][idx] = true;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
int nextLen = permOut.length >>> 1;
|
||||||
|
int[] upperPermIn = new int[nextLen];
|
||||||
|
int[] lowerPermIn = new int[nextLen];
|
||||||
|
|
||||||
|
int[] upperPermOut = new int[nextLen];
|
||||||
|
int[] lowerPermOut = new int[nextLen];
|
||||||
|
|
||||||
|
|
||||||
|
SortedSet<Integer> unmatchedIndices = new TreeSet<>();
|
||||||
|
for (int i = 0; i < permIn.length; ++i) {
|
||||||
|
unmatchedIndices.add(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Where the set of switches for this level and index actually starts in switchValues
|
||||||
|
*/
|
||||||
|
int blockStart = idx << level;
|
||||||
|
int numSwitches = 1 << (logN - level - 1);
|
||||||
|
|
||||||
|
while (!unmatchedIndices.isEmpty()) {
|
||||||
|
int i = unmatchedIndices.first();
|
||||||
|
|
||||||
|
do {
|
||||||
|
//assert ((i & 1) == 0); // Algorithm guarantees only even indices are unmatched at start of loop
|
||||||
|
unmatchedIndices.remove(i);
|
||||||
|
|
||||||
|
int switchNum = i >>> 1;
|
||||||
|
|
||||||
|
if (((i & 1) == 0) && (switchNum < numSwitches / 2) || ((i & 1) == 1) && (switchNum >= numSwitches / 2)) {
|
||||||
|
// i is in the upper half and even, or in the lower half and odd, so
|
||||||
|
// switch must be "straight" to get i to upper half.
|
||||||
|
switchValues[2 * logN - level - 1][blockStart + switchNum] = false;
|
||||||
|
|
||||||
|
} else {
|
||||||
|
switchValues[2 * logN - level - 1][blockStart + switchNum] = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
int iConnectedTo = getOutputIdxInPreviousLayer(logN - level, 2 * (logN - level), i);
|
||||||
|
int iPairConnectedTo = getOutputIdxInPreviousLayer(logN - level, 2 * (logN - level), i ^ 1);
|
||||||
|
|
||||||
|
upperPermOut[iConnectedTo] = permOut[i];
|
||||||
|
lowerPermOut[iPairConnectedTo] = permOut[i ^ 1];
|
||||||
|
|
||||||
|
int j = find(permOut[i], permIn);
|
||||||
|
int jSwitchNum = j >>> 1;
|
||||||
|
int j
|
||||||
|
|
||||||
|
if ((((j & 1) == 0) && (jSwitchNum < numSwitches / 2)) || ((j & 1) == 1) && (jSwitchNum >= numSwitches / 2)) {
|
||||||
|
// Even output in the upper half, or odd output in the lower half
|
||||||
|
// so switch needs to be "straight" to get j to the upper half
|
||||||
|
switchValues[level][blockStart + switchNum] = false;
|
||||||
|
} else {
|
||||||
|
// Otherwise switch needs to be "crossed" to get j to upper half
|
||||||
|
switchValues[level][blockStart + switchNum] = true;
|
||||||
|
j ^= 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
int jConnectedTo = getInputIdxInNextLayer(logN - level, 0, j);
|
||||||
|
upperPermIn[jConnectedTo] = j;
|
||||||
|
|
||||||
|
lowerPermIn[i] = permIn[i + 1];
|
||||||
|
|
||||||
|
} while (unmatchedIndices.contains(i));
|
||||||
|
}
|
||||||
|
|
||||||
|
setInternalPermutation(upperPermIn, upperPermOut, idx, level + 1);
|
||||||
|
setInternalPermutation(lowerPermIn, lowerPermOut, idx, level + 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setPermutation(int[] permutation) {
|
||||||
|
|
||||||
|
// We use the recursive algorithm attributed to ? in ?'s paper
|
||||||
|
int level = 0;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public <T> void permuteLayer(T[] values, int layer) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public <T> void permuteLayer(ArrayList<T> values, int layer) {
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
|
@ -23,7 +23,7 @@ public class MixNetwork {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* implements benes mix network algorithm
|
* implements Benes mix network algorithm
|
||||||
* @param permutation - random permutation
|
* @param permutation - random permutation
|
||||||
* @return switches
|
* @return switches
|
||||||
*/
|
*/
|
||||||
|
|
|
@ -0,0 +1,39 @@
|
||||||
|
package meerkat.mixer.mixing;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A generic permutation network composed of 2-input switches.
|
||||||
|
*/
|
||||||
|
public interface PermutationNetwork {
|
||||||
|
/**
|
||||||
|
* Return the number of inputs supported by the network.
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
int getNumInputs();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Return the number of layers supported by the network.
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
int getNumLayers();
|
||||||
|
|
||||||
|
int getOutputIdxInPreviousLayer(int layer, int inputIdx);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initialize switches to generate a specific permutation.
|
||||||
|
* @param permutation
|
||||||
|
*/
|
||||||
|
void setPermutation(int[] permutation);
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Apply a single layer's permutation (as implied by {@link #setPermutation(int[])})
|
||||||
|
* @param values values output by previous layer
|
||||||
|
* @param layer layer to apply (layer is between 1 and numLayers; layer 0 is the input layer)
|
||||||
|
* @param <T> type of values.
|
||||||
|
*/
|
||||||
|
<T> void permuteLayer(T[] values, int layer);
|
||||||
|
|
||||||
|
<T> void permuteLayer(ArrayList<T> values, int layer);
|
||||||
|
}
|
|
@ -27,19 +27,20 @@ public class RandomPermutation {
|
||||||
* @param random
|
* @param random
|
||||||
* @return permutation
|
* @return permutation
|
||||||
*/
|
*/
|
||||||
private int[] generatePermutation(int n,Random random){
|
public static int[] generatePermutation(int n, Random random){
|
||||||
List<Integer> numbers= new ArrayList<Integer>(n);
|
int[] result = new int[n];
|
||||||
for (int i = 0; i < n; i++) {
|
|
||||||
numbers.add(i);
|
// initialize and permute in one pass using "inside-out" Fisher-Yates
|
||||||
|
for (int i = 0; i < n; ++i) {
|
||||||
|
int j = random.nextInt(i+1);
|
||||||
|
|
||||||
|
if (j != i) {
|
||||||
|
result[i] = result[j];
|
||||||
}
|
}
|
||||||
|
|
||||||
int[] result = new int[n];
|
result[j] = i;
|
||||||
int index;
|
|
||||||
for (int i = 0; i < n; i++) {
|
|
||||||
index = random.nextInt(n - i);
|
|
||||||
result[i] = numbers.get(index);
|
|
||||||
numbers.remove(index);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,15 @@
|
||||||
|
package meerkat.mixer.mixing;
|
||||||
|
|
||||||
|
import static org.junit.Assert.*;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Tests for Benes Network topology
|
||||||
|
*/
|
||||||
|
public class BenesNetworkTest {
|
||||||
|
public static class Permutation {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public static isAPerm
|
||||||
|
|
||||||
|
}
|
Loading…
Reference in New Issue