From 37d1857f9c74e10c6d9fc02295f5b59a640dbbb9 Mon Sep 17 00:00:00 2001 From: Tal Moran Date: Wed, 18 Jan 2017 21:54:42 +0200 Subject: [PATCH] Working on Benes network/testing --- .../meerkat/mixer/mixing/BenesNetwork.java | 79 ++++++++----- .../mixer/mixing/PermutationNetwork.java | 28 ++++- .../mixer/mixing/RandomPermutation.java | 17 ++- .../main/java/meerkat/mixer/mixing/Util.java | 45 ++++++++ .../test/java/meerkat/mixer/MixingTest.java | 2 +- .../mixer/mixing/BenesNetworkTest.java | 15 ++- .../mixer/mixing/PermutationNetworkTest.java | 109 ++++++++++++++++++ 7 files changed, 252 insertions(+), 43 deletions(-) create mode 100644 mixer/src/main/java/meerkat/mixer/mixing/Util.java create mode 100644 mixer/src/test/java/meerkat/mixer/mixing/PermutationNetworkTest.java diff --git a/mixer/src/main/java/meerkat/mixer/mixing/BenesNetwork.java b/mixer/src/main/java/meerkat/mixer/mixing/BenesNetwork.java index c4bfd28..10cc45e 100644 --- a/mixer/src/main/java/meerkat/mixer/mixing/BenesNetwork.java +++ b/mixer/src/main/java/meerkat/mixer/mixing/BenesNetwork.java @@ -39,25 +39,22 @@ public class BenesNetwork implements PermutationNetwork * @return the requested index */ public static int getOutputIdxInPreviousLayer(int logN, int layer, int inputIdx) { - assert (layer > 0) && (layer < 2*logN - 1); + assert (layer >= 0) && (layer < 2*logN - 1); assert (inputIdx >= 0) && (inputIdx < 1 << inputIdx); - if ((inputIdx & 1) == 0) { - // Even inputs are connected straight "across" everywhere + if ((layer == 0) || (inputIdx & 1) == 0) { + // layer 0 inputs and all even inputs + // are connected straight "across" everywhere return inputIdx; } // --- Odd inputs are connected depending on layer --- - // In middle layer everything goes across - if (layer == logN) - return inputIdx; - int crossBit; if (layer < logN) crossBit = logN - layer; else - crossBit = layer - logN; + crossBit = layer - logN + 1; return inputIdx ^ (1 << crossBit); } @@ -73,6 +70,11 @@ public class BenesNetwork implements PermutationNetwork return getOutputIdxInPreviousLayer(logN, layer + 1, outputIdx); } + @Override + public int getInputIdxInNextLayer(int layer, int outputIdx) { + return getInputIdxInNextLayer(logN, layer, outputIdx); + } + @Override public int getNumInputs() { return 1 << logN; @@ -111,11 +113,13 @@ public class BenesNetwork implements PermutationNetwork * @param level */ private void setInternalPermutation(int[] permIn, int[] permOut, int idx, int level) { - // Make sure idx is in the the range [0, 2^{logN-level-1}) - assert (idx == (idx & ((1 << (logN - level - 1)) - 1))); + // Make sure idx is in the the range [0, 2^{level}) + assert (idx < (1 << level)); assert (permOut.length == (1 << (logN - level))); assert (permIn.length == permOut.length); + final int Nover2 = 1 << (logN - level - 1); + // Base case if (level == logN - 1) { if (permOut[0] == permIn[0]) @@ -150,65 +154,80 @@ public class BenesNetwork implements PermutationNetwork do { //assert ((i & 1) == 0); // Algorithm guarantees only even indices are unmatched at start of loop unmatchedIndices.remove(i); + unmatchedIndices.remove(i ^ 1); int switchNum = i >>> 1; + // index of outPerm[i] after the last switch layer. + int iSwitched; + if (((i & 1) == 0) && (switchNum < numSwitches / 2) || ((i & 1) == 1) && (switchNum >= numSwitches / 2)) { // i is in the upper half and even, or in the lower half and odd, so // switch must be "straight" to get i to upper half. - switchValues[2 * logN - level - 1][blockStart + switchNum] = false; + switchValues[2 * logN - level - 2][blockStart + switchNum] = false; + iSwitched = i; } else { - switchValues[2 * logN - level - 1][blockStart + switchNum] = true; + switchValues[2 * logN - level - 2][blockStart + switchNum] = true; + iSwitched = i ^ 1; } - int iConnectedTo = getOutputIdxInPreviousLayer(logN - level, 2 * (logN - level), i); - int iPairConnectedTo = getOutputIdxInPreviousLayer(logN - level, 2 * (logN - level), i ^ 1); + // Index of outPerm[i] in the output of the previous layer. + int iConnectedTo = getOutputIdxInPreviousLayer(logN - level, 2 * (logN - level) - 2, iSwitched); + int iPairConnectedTo = getOutputIdxInPreviousLayer(logN - level, 2 * (logN - level) - 2, iSwitched ^ 1); upperPermOut[iConnectedTo] = permOut[i]; - lowerPermOut[iPairConnectedTo] = permOut[i ^ 1]; + lowerPermOut[iPairConnectedTo - Nover2] = permOut[i ^ 1]; + // Index of permOut[i] before the first switch layer. int j = find(permOut[i], permIn); int jSwitchNum = j >>> 1; - int j + + // Index of permOut[i] after the first switch layer + int jSwitched; if ((((j & 1) == 0) && (jSwitchNum < numSwitches / 2)) || ((j & 1) == 1) && (jSwitchNum >= numSwitches / 2)) { // Even output in the upper half, or odd output in the lower half // so switch needs to be "straight" to get j to the upper half switchValues[level][blockStart + switchNum] = false; + jSwitched = j; } else { // Otherwise switch needs to be "crossed" to get j to upper half switchValues[level][blockStart + switchNum] = true; - j ^= 1; + jSwitched = j ^ 1; } - int jConnectedTo = getInputIdxInNextLayer(logN - level, 0, j); - upperPermIn[jConnectedTo] = j; + int jConnectedTo = getInputIdxInNextLayer(logN - level, 0, jSwitched); + // must be in upper half + assert(jConnectedTo < (1 << (logN - 1))); + upperPermIn[jConnectedTo] = permOut[i]; - lowerPermIn[i] = permIn[i + 1]; + // Connect from j's pair to through the lower half + int jPairConnectedTo = getInputIdxInNextLayer(logN - level, 0, jSwitched ^ 1); + lowerPermIn[jPairConnectedTo - Nover2] = permIn[j ^ 1]; + int iPairNext = find(permIn[j ^ 1], permOut); + i = iPairNext ^ 1; } while (unmatchedIndices.contains(i)); } - setInternalPermutation(upperPermIn, upperPermOut, idx, level + 1); - setInternalPermutation(lowerPermIn, lowerPermOut, idx, level + 1); + setInternalPermutation(upperPermIn, upperPermOut, 2*idx, level + 1); + setInternalPermutation(lowerPermIn, lowerPermOut, 2*idx + 1, level + 1); } @Override public void setPermutation(int[] permutation) { - - // We use the recursive algorithm attributed to ? in ?'s paper int level = 0; + int[] identity = new int[getNumInputs()]; + for (int i = 0; i < identity.length; ++i) + identity[i] = i; + setInternalPermutation(identity, permutation, 0, 0); } @Override - public void permuteLayer(T[] values, int layer) { - + public boolean isCrossed(int layer, int switchIdx) { + return switchValues[layer][switchIdx]; } - @Override - public void permuteLayer(ArrayList values, int layer) { - - } } diff --git a/mixer/src/main/java/meerkat/mixer/mixing/PermutationNetwork.java b/mixer/src/main/java/meerkat/mixer/mixing/PermutationNetwork.java index 7a224ab..4dc1e0e 100644 --- a/mixer/src/main/java/meerkat/mixer/mixing/PermutationNetwork.java +++ b/mixer/src/main/java/meerkat/mixer/mixing/PermutationNetwork.java @@ -18,8 +18,26 @@ public interface PermutationNetwork { */ int getNumLayers(); + /** + * Return the index of the switch output in the previous layer that connects to a specified input in the current layer. + * Each consecutive pair of outputs is from the same switch, and each consecutive pair of inputs goes to the same switch. + * (that is, switch j has inputs/outputs 2j and 2j+1) + * + * @param layer current layer index. Must be in the range [0, numLayers) (layer 0's previous layer consists of the inputs themselves) + * @param inputIdx the input Idx for the current layer. Must be in the range [0, numInputs) + * + * @return the requested index + */ int getOutputIdxInPreviousLayer(int layer, int inputIdx); + /** + * The inverse of {@link }#getOutputIdxInPreviousLayer(int,int)}. + * @param layer + * @param inputIdx + * @return + */ + int getInputIdxInNextLayer(int layer, int inputIdx); + /** * Initialize switches to generate a specific permutation. * @param permutation @@ -28,12 +46,10 @@ public interface PermutationNetwork { /** - * Apply a single layer's permutation (as implied by {@link #setPermutation(int[])}) - * @param values values output by previous layer - * @param layer layer to apply (layer is between 1 and numLayers; layer 0 is the input layer) - * @param type of values. + * Returns true iff the switch is crossed + * @param layer + * @return */ - void permuteLayer(T[] values, int layer); + boolean isCrossed(int layer, int switchIdx); - void permuteLayer(ArrayList values, int layer); } diff --git a/mixer/src/main/java/meerkat/mixer/mixing/RandomPermutation.java b/mixer/src/main/java/meerkat/mixer/mixing/RandomPermutation.java index 85482ce..537fae2 100644 --- a/mixer/src/main/java/meerkat/mixer/mixing/RandomPermutation.java +++ b/mixer/src/main/java/meerkat/mixer/mixing/RandomPermutation.java @@ -7,7 +7,7 @@ import java.util.Random; /** * Created by Tzlil on 12/17/2015. * container for random permutation - * the permutation is sated in constructor and can't be change + * the permutation is set in the constructor and can't be changed */ public class RandomPermutation { public final int[] permutation; @@ -43,4 +43,19 @@ public class RandomPermutation { return result; } + + + /** + * randomly shuffle an array of elements in-place (using Fisher-Yates) + * @param perm array of elements to shuffle. + * @param random + */ + public static void permute(T[] perm, Random random){ + for (int i = perm.length - 1; i > 0; --i) { + int j = random.nextInt(i+1); + T tmp = perm[i]; + perm[i] = perm[j]; + perm[j] = tmp; + } + } } diff --git a/mixer/src/main/java/meerkat/mixer/mixing/Util.java b/mixer/src/main/java/meerkat/mixer/mixing/Util.java new file mode 100644 index 0000000..26437ae --- /dev/null +++ b/mixer/src/main/java/meerkat/mixer/mixing/Util.java @@ -0,0 +1,45 @@ +package meerkat.mixer.mixing; + +import java.util.ArrayList; + +/** + * Generic utilities + */ +public class Util { + /* + * Apply a single layer's switch permutations in place (as implied by {@link PermutationNetwork#setPermutation(int[])}) + * @param values values output by previous layxer + * @param layer layer to apply (layer is between 1 and numLayers; layer 0 is the input layer) + * @param type of values. + */ + public static void applyLayerSwitches(PermutationNetwork net, T[] values, int layer) { + int numSwitches = net.getNumInputs() >>> 1; + for (int i = 0; i < numSwitches; ++i) { + if (net.isCrossed(layer, i)) { + T tmp = values[i * 2]; + values[i * 2] = values[i * 2 + 1]; + values[i * 2 + 1] = tmp; + } + } + } + + + public static void convertOutputPermToInputPerm(PermutationNetwork net, T[] oldValues, T[] newValues, int layer) { + int numInputs = net.getNumInputs(); + assert(numInputs == oldValues.length); + assert(numInputs == newValues.length); + + for (int i = 0; i < numInputs; ++i) { + newValues[i] = oldValues[net.getOutputIdxInPreviousLayer(layer, i)]; + } + } + + + public static void permute(PermutationNetwork net, T[] oldValues, T[] newValues) { + for (int layer = 0; layer < net.getNumLayers(); ++layer) { + convertOutputPermToInputPerm(net, oldValues, newValues, layer); + applyLayerSwitches(net, newValues, layer); + System.arraycopy(newValues, 0, oldValues, 0, oldValues.length); + } + } +} diff --git a/mixer/src/test/java/meerkat/mixer/MixingTest.java b/mixer/src/test/java/meerkat/mixer/MixingTest.java index 87eb4e8..df5b8aa 100644 --- a/mixer/src/test/java/meerkat/mixer/MixingTest.java +++ b/mixer/src/test/java/meerkat/mixer/MixingTest.java @@ -45,7 +45,7 @@ public class MixingTest extends ECParamTestBase { mixer = new Mixer(prover, enc); // generate n - int logN = 8; // + random.nextInt(8) + int logN = 9; // + random.nextInt(8) layers = 2*logN - 1; n = 1 << logN; } diff --git a/mixer/src/test/java/meerkat/mixer/mixing/BenesNetworkTest.java b/mixer/src/test/java/meerkat/mixer/mixing/BenesNetworkTest.java index 1f6be55..7a64f58 100644 --- a/mixer/src/test/java/meerkat/mixer/mixing/BenesNetworkTest.java +++ b/mixer/src/test/java/meerkat/mixer/mixing/BenesNetworkTest.java @@ -1,15 +1,20 @@ package meerkat.mixer.mixing; +import org.junit.Test; + +import java.util.Set; +import java.util.TreeSet; + import static org.junit.Assert.*; /** * Tests for Benes Network topology */ -public class BenesNetworkTest { - public static class Permutation { +public class BenesNetworkTest extends PermutationNetworkTest { + final static int logN = 3; + @Override + protected PermutationNetwork getNewNetwork() { + return new BenesNetwork(logN); } - - public static isAPerm - } \ No newline at end of file diff --git a/mixer/src/test/java/meerkat/mixer/mixing/PermutationNetworkTest.java b/mixer/src/test/java/meerkat/mixer/mixing/PermutationNetworkTest.java new file mode 100644 index 0000000..2c3aa25 --- /dev/null +++ b/mixer/src/test/java/meerkat/mixer/mixing/PermutationNetworkTest.java @@ -0,0 +1,109 @@ +package meerkat.mixer.mixing; + +import org.junit.Before; +import org.junit.Test; + +import java.util.Random; +import java.util.Set; +import java.util.TreeSet; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** + * Generically Test a permutation network + */ +abstract public class PermutationNetworkTest { + + public static final int NUM_REPS = 10; + Random rand; + + abstract protected PermutationNetwork getNewNetwork(); + + PermutationNetwork network; + + @Before + public void setup() { + network = getNewNetwork(); + rand = new Random(1); + } + + public static Set getSequenceSet(int N) { + Set set = new TreeSet<>(); + + for (int i = 0; i < N; ++i) { + set.add(i); + } + return set; + } + + public static Integer[] getSequenceArray(int N) { + Integer[] arr = new Integer[N]; + + for (int i = 0; i < N; ++i) { + arr[i] = i; + } + return arr; + } + + + /** + * Check if a given network is actually a permutation network (i.e., always + * implies a permutation regardless of 2x2 switch settings). + */ + @Test + public void isAlwaysAPermutation() { + int numLayers = network.getNumLayers(); + int N = network.getNumInputs(); + + for (int layer = 1; layer < numLayers; ++layer) { + Set unusedInputs = getSequenceSet(N); + for (int i = 0; i < N; ++i) { + unusedInputs.remove(network.getOutputIdxInPreviousLayer(layer, i)); + } + assertTrue("Not a permutation! Didn't use: " + unusedInputs, unusedInputs.isEmpty()); + } + } + + @Test + public void forwardEqualsBackwards() { + int numLayers = network.getNumLayers(); + int N = network.getNumInputs(); + + for (int layer = 1; layer < numLayers; ++layer) { + for (int i = 0; i < N; ++i) { + int j = network.getOutputIdxInPreviousLayer(layer, i); + assertEquals(String.format("Input %d in layer %d has problems", i, layer), i, + network.getInputIdxInNextLayer(layer - 1, j)); + } + } + } + + public static int[] convert(Integer[] in) { + int[] out = new int[in.length]; + for (int i = 0; i < in.length; ++i) + out[i] = in[i]; + + return out; + } + + @Test + public void testRandomPermutations() throws Exception { + for (int rep = 0; rep < NUM_REPS; ++rep) { + Integer[] target = getSequenceArray(network.getNumInputs()); + RandomPermutation.permute(target, rand); + + network.setPermutation(convert(target)); + + Integer[] id = getSequenceArray(network.getNumInputs()); + + Integer[] out = new Integer[target.length]; + + Util.permute(network, id, out); + + assertArrayEquals("Permutation mismatch: " + target + " != " + out, target, out); + + } + } +}