From 0d416f0018f8a1f5414f65422635fa15023d321e Mon Sep 17 00:00:00 2001 From: Tal Moran Date: Wed, 18 Jan 2017 11:47:52 +0200 Subject: [PATCH] Rewriting Benes network code (in progress) --- .../meerkat/mixer/mixing/BenesNetwork.java | 214 ++++++++++++++++++ .../java/meerkat/mixer/mixing/MixNetwork.java | 2 +- .../mixer/mixing/PermutationNetwork.java | 39 ++++ .../mixer/mixing/RandomPermutation.java | 25 +- .../mixer/mixing/BenesNetworkTest.java | 15 ++ 5 files changed, 282 insertions(+), 13 deletions(-) create mode 100644 mixer/src/main/java/meerkat/mixer/mixing/BenesNetwork.java create mode 100644 mixer/src/main/java/meerkat/mixer/mixing/PermutationNetwork.java create mode 100644 mixer/src/test/java/meerkat/mixer/mixing/BenesNetworkTest.java diff --git a/mixer/src/main/java/meerkat/mixer/mixing/BenesNetwork.java b/mixer/src/main/java/meerkat/mixer/mixing/BenesNetwork.java new file mode 100644 index 0000000..c4bfd28 --- /dev/null +++ b/mixer/src/main/java/meerkat/mixer/mixing/BenesNetwork.java @@ -0,0 +1,214 @@ +package meerkat.mixer.mixing; + +import java.util.*; + +/** + * Generate the Benes Network + */ +public class BenesNetwork implements PermutationNetwork +{ + /** + * log (base 2) of number of inputs to Benes network (N = 2^{logN}) + */ + final int logN; + + /** + * Values of switches set for a specific permutation. + * switchValues[layer][switchNum] is true if the corresponding switch applies the identity permutation on its inputs. + * layer can be in the range [0, 2*logN-1), switchNum in the range (0, 2^{logN-1}-1) + */ + final boolean[][] switchValues; + + public BenesNetwork(int logN) { + this.logN = logN; + + switchValues = new boolean[2*logN - 1][]; + for (int i = 0; i < switchValues.length; ++i) + switchValues[i] = new boolean[1 << (logN - 1)]; + } + + /** + * Return the index of the switch output in the previous layer that connects to a specified input in the current layer. + * Each consecutive pair of outputs is from the same switch, and each consecutive pair of inputs goes to the same switch. + * (that is, switch j has inputs/outputs 2j and 2j+1) + * + * @param logN log (base 2) of number of inputs to Benes network (N = 2^{logN}) + * @param layer current layer index. Must be between 1 and 2*logN-2 (layer 0 doesn't have a previous layer) + * @param inputIdx the input Idx for the current layer (must be between 0 and (1 << logN) - 1 + * + * @return the requested index + */ + public static int getOutputIdxInPreviousLayer(int logN, int layer, int inputIdx) { + assert (layer > 0) && (layer < 2*logN - 1); + assert (inputIdx >= 0) && (inputIdx < 1 << inputIdx); + + if ((inputIdx & 1) == 0) { + // Even inputs are connected straight "across" everywhere + return inputIdx; + } + + // --- Odd inputs are connected depending on layer --- + + // In middle layer everything goes across + if (layer == logN) + return inputIdx; + + int crossBit; + if (layer < logN) + crossBit = logN - layer; + else + crossBit = layer - logN; + + return inputIdx ^ (1 << crossBit); + } + + /** + * Inverse of {@link #getOutputIdxInPreviousLayer(int, int, int)} + * @param logN + * @param layer + * @param outputIdx + * @return + */ + public static int getInputIdxInNextLayer(int logN, int layer, int outputIdx) { + return getOutputIdxInPreviousLayer(logN, layer + 1, outputIdx); + } + + @Override + public int getNumInputs() { + return 1 << logN; + } + + @Override + public int getNumLayers() { + return 2*logN - 1; + } + + @Override + public int getOutputIdxInPreviousLayer(int layer, int inputIdx) { + return getOutputIdxInPreviousLayer(logN, layer, inputIdx); + } + + + /** + * Find in array + * TODO: replace with more efficient data structure if this becomes a performance bottleneck + * @param val + * @param arr + * @return + */ + static private int find(int val, int[] arr) { + for (int i = 0; i < arr.length; ++i) { + if (arr[i] == val) + return i; + } + return -1; + } + + /** + * The recursive algorithm attributed to H. Stone by A. Waksman + * (see pg. 161 in paper) + * @param permOut + * @param level + */ + private void setInternalPermutation(int[] permIn, int[] permOut, int idx, int level) { + // Make sure idx is in the the range [0, 2^{logN-level-1}) + assert (idx == (idx & ((1 << (logN - level - 1)) - 1))); + assert (permOut.length == (1 << (logN - level))); + assert (permIn.length == permOut.length); + + // Base case + if (level == logN - 1) { + if (permOut[0] == permIn[0]) + switchValues[level][idx] = false; + else + switchValues[level][idx] = true; + return; + } + + int nextLen = permOut.length >>> 1; + int[] upperPermIn = new int[nextLen]; + int[] lowerPermIn = new int[nextLen]; + + int[] upperPermOut = new int[nextLen]; + int[] lowerPermOut = new int[nextLen]; + + + SortedSet unmatchedIndices = new TreeSet<>(); + for (int i = 0; i < permIn.length; ++i) { + unmatchedIndices.add(i); + } + + /** + * Where the set of switches for this level and index actually starts in switchValues + */ + int blockStart = idx << level; + int numSwitches = 1 << (logN - level - 1); + + while (!unmatchedIndices.isEmpty()) { + int i = unmatchedIndices.first(); + + do { + //assert ((i & 1) == 0); // Algorithm guarantees only even indices are unmatched at start of loop + unmatchedIndices.remove(i); + + int switchNum = i >>> 1; + + if (((i & 1) == 0) && (switchNum < numSwitches / 2) || ((i & 1) == 1) && (switchNum >= numSwitches / 2)) { + // i is in the upper half and even, or in the lower half and odd, so + // switch must be "straight" to get i to upper half. + switchValues[2 * logN - level - 1][blockStart + switchNum] = false; + + } else { + switchValues[2 * logN - level - 1][blockStart + switchNum] = true; + } + + int iConnectedTo = getOutputIdxInPreviousLayer(logN - level, 2 * (logN - level), i); + int iPairConnectedTo = getOutputIdxInPreviousLayer(logN - level, 2 * (logN - level), i ^ 1); + + upperPermOut[iConnectedTo] = permOut[i]; + lowerPermOut[iPairConnectedTo] = permOut[i ^ 1]; + + int j = find(permOut[i], permIn); + int jSwitchNum = j >>> 1; + int j + + if ((((j & 1) == 0) && (jSwitchNum < numSwitches / 2)) || ((j & 1) == 1) && (jSwitchNum >= numSwitches / 2)) { + // Even output in the upper half, or odd output in the lower half + // so switch needs to be "straight" to get j to the upper half + switchValues[level][blockStart + switchNum] = false; + } else { + // Otherwise switch needs to be "crossed" to get j to upper half + switchValues[level][blockStart + switchNum] = true; + j ^= 1; + } + + int jConnectedTo = getInputIdxInNextLayer(logN - level, 0, j); + upperPermIn[jConnectedTo] = j; + + lowerPermIn[i] = permIn[i + 1]; + + } while (unmatchedIndices.contains(i)); + } + + setInternalPermutation(upperPermIn, upperPermOut, idx, level + 1); + setInternalPermutation(lowerPermIn, lowerPermOut, idx, level + 1); + } + + @Override + public void setPermutation(int[] permutation) { + + // We use the recursive algorithm attributed to ? in ?'s paper + int level = 0; + + } + + @Override + public void permuteLayer(T[] values, int layer) { + + } + + @Override + public void permuteLayer(ArrayList values, int layer) { + + } +} diff --git a/mixer/src/main/java/meerkat/mixer/mixing/MixNetwork.java b/mixer/src/main/java/meerkat/mixer/mixing/MixNetwork.java index e820ebd..99f041e 100644 --- a/mixer/src/main/java/meerkat/mixer/mixing/MixNetwork.java +++ b/mixer/src/main/java/meerkat/mixer/mixing/MixNetwork.java @@ -23,7 +23,7 @@ public class MixNetwork { } /** - * implements benes mix network algorithm + * implements Benes mix network algorithm * @param permutation - random permutation * @return switches */ diff --git a/mixer/src/main/java/meerkat/mixer/mixing/PermutationNetwork.java b/mixer/src/main/java/meerkat/mixer/mixing/PermutationNetwork.java new file mode 100644 index 0000000..7a224ab --- /dev/null +++ b/mixer/src/main/java/meerkat/mixer/mixing/PermutationNetwork.java @@ -0,0 +1,39 @@ +package meerkat.mixer.mixing; + +import java.util.ArrayList; + +/** + * A generic permutation network composed of 2-input switches. + */ +public interface PermutationNetwork { + /** + * Return the number of inputs supported by the network. + * @return + */ + int getNumInputs(); + + /** + * Return the number of layers supported by the network. + * @return + */ + int getNumLayers(); + + int getOutputIdxInPreviousLayer(int layer, int inputIdx); + + /** + * Initialize switches to generate a specific permutation. + * @param permutation + */ + void setPermutation(int[] permutation); + + + /** + * Apply a single layer's permutation (as implied by {@link #setPermutation(int[])}) + * @param values values output by previous layer + * @param layer layer to apply (layer is between 1 and numLayers; layer 0 is the input layer) + * @param type of values. + */ + void permuteLayer(T[] values, int layer); + + void permuteLayer(ArrayList values, int layer); +} diff --git a/mixer/src/main/java/meerkat/mixer/mixing/RandomPermutation.java b/mixer/src/main/java/meerkat/mixer/mixing/RandomPermutation.java index 7f18530..85482ce 100644 --- a/mixer/src/main/java/meerkat/mixer/mixing/RandomPermutation.java +++ b/mixer/src/main/java/meerkat/mixer/mixing/RandomPermutation.java @@ -17,7 +17,7 @@ public class RandomPermutation { * @param n permutation size * @param random */ - public RandomPermutation(int n,Random random) { + public RandomPermutation(int n, Random random) { this.permutation = generatePermutation(n,random); } @@ -27,19 +27,20 @@ public class RandomPermutation { * @param random * @return permutation */ - private int[] generatePermutation(int n,Random random){ - List numbers= new ArrayList(n); - for (int i = 0; i < n; i++) { - numbers.add(i); + public static int[] generatePermutation(int n, Random random){ + int[] result = new int[n]; + + // initialize and permute in one pass using "inside-out" Fisher-Yates + for (int i = 0; i < n; ++i) { + int j = random.nextInt(i+1); + + if (j != i) { + result[i] = result[j]; + } + + result[j] = i; } - int[] result = new int[n]; - int index; - for (int i = 0; i < n; i++) { - index = random.nextInt(n - i); - result[i] = numbers.get(index); - numbers.remove(index); - } return result; } } diff --git a/mixer/src/test/java/meerkat/mixer/mixing/BenesNetworkTest.java b/mixer/src/test/java/meerkat/mixer/mixing/BenesNetworkTest.java new file mode 100644 index 0000000..1f6be55 --- /dev/null +++ b/mixer/src/test/java/meerkat/mixer/mixing/BenesNetworkTest.java @@ -0,0 +1,15 @@ +package meerkat.mixer.mixing; + +import static org.junit.Assert.*; + +/** + * Tests for Benes Network topology + */ +public class BenesNetworkTest { + public static class Permutation { + + } + + public static isAPerm + +} \ No newline at end of file