Working on Benes network/testing
parent
abf4cc5e54
commit
37d1857f9c
|
@ -39,25 +39,22 @@ public class BenesNetwork implements PermutationNetwork
|
||||||
* @return the requested index
|
* @return the requested index
|
||||||
*/
|
*/
|
||||||
public static int getOutputIdxInPreviousLayer(int logN, int layer, int inputIdx) {
|
public static int getOutputIdxInPreviousLayer(int logN, int layer, int inputIdx) {
|
||||||
assert (layer > 0) && (layer < 2*logN - 1);
|
assert (layer >= 0) && (layer < 2*logN - 1);
|
||||||
assert (inputIdx >= 0) && (inputIdx < 1 << inputIdx);
|
assert (inputIdx >= 0) && (inputIdx < 1 << inputIdx);
|
||||||
|
|
||||||
if ((inputIdx & 1) == 0) {
|
if ((layer == 0) || (inputIdx & 1) == 0) {
|
||||||
// Even inputs are connected straight "across" everywhere
|
// layer 0 inputs and all even inputs
|
||||||
|
// are connected straight "across" everywhere
|
||||||
return inputIdx;
|
return inputIdx;
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- Odd inputs are connected depending on layer ---
|
// --- Odd inputs are connected depending on layer ---
|
||||||
|
|
||||||
// In middle layer everything goes across
|
|
||||||
if (layer == logN)
|
|
||||||
return inputIdx;
|
|
||||||
|
|
||||||
int crossBit;
|
int crossBit;
|
||||||
if (layer < logN)
|
if (layer < logN)
|
||||||
crossBit = logN - layer;
|
crossBit = logN - layer;
|
||||||
else
|
else
|
||||||
crossBit = layer - logN;
|
crossBit = layer - logN + 1;
|
||||||
|
|
||||||
return inputIdx ^ (1 << crossBit);
|
return inputIdx ^ (1 << crossBit);
|
||||||
}
|
}
|
||||||
|
@ -73,6 +70,11 @@ public class BenesNetwork implements PermutationNetwork
|
||||||
return getOutputIdxInPreviousLayer(logN, layer + 1, outputIdx);
|
return getOutputIdxInPreviousLayer(logN, layer + 1, outputIdx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getInputIdxInNextLayer(int layer, int outputIdx) {
|
||||||
|
return getInputIdxInNextLayer(logN, layer, outputIdx);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int getNumInputs() {
|
public int getNumInputs() {
|
||||||
return 1 << logN;
|
return 1 << logN;
|
||||||
|
@ -111,11 +113,13 @@ public class BenesNetwork implements PermutationNetwork
|
||||||
* @param level
|
* @param level
|
||||||
*/
|
*/
|
||||||
private void setInternalPermutation(int[] permIn, int[] permOut, int idx, int level) {
|
private void setInternalPermutation(int[] permIn, int[] permOut, int idx, int level) {
|
||||||
// Make sure idx is in the the range [0, 2^{logN-level-1})
|
// Make sure idx is in the the range [0, 2^{level})
|
||||||
assert (idx == (idx & ((1 << (logN - level - 1)) - 1)));
|
assert (idx < (1 << level));
|
||||||
assert (permOut.length == (1 << (logN - level)));
|
assert (permOut.length == (1 << (logN - level)));
|
||||||
assert (permIn.length == permOut.length);
|
assert (permIn.length == permOut.length);
|
||||||
|
|
||||||
|
final int Nover2 = 1 << (logN - level - 1);
|
||||||
|
|
||||||
// Base case
|
// Base case
|
||||||
if (level == logN - 1) {
|
if (level == logN - 1) {
|
||||||
if (permOut[0] == permIn[0])
|
if (permOut[0] == permIn[0])
|
||||||
|
@ -150,65 +154,80 @@ public class BenesNetwork implements PermutationNetwork
|
||||||
do {
|
do {
|
||||||
//assert ((i & 1) == 0); // Algorithm guarantees only even indices are unmatched at start of loop
|
//assert ((i & 1) == 0); // Algorithm guarantees only even indices are unmatched at start of loop
|
||||||
unmatchedIndices.remove(i);
|
unmatchedIndices.remove(i);
|
||||||
|
unmatchedIndices.remove(i ^ 1);
|
||||||
|
|
||||||
int switchNum = i >>> 1;
|
int switchNum = i >>> 1;
|
||||||
|
|
||||||
|
// index of outPerm[i] after the last switch layer.
|
||||||
|
int iSwitched;
|
||||||
|
|
||||||
if (((i & 1) == 0) && (switchNum < numSwitches / 2) || ((i & 1) == 1) && (switchNum >= numSwitches / 2)) {
|
if (((i & 1) == 0) && (switchNum < numSwitches / 2) || ((i & 1) == 1) && (switchNum >= numSwitches / 2)) {
|
||||||
// i is in the upper half and even, or in the lower half and odd, so
|
// i is in the upper half and even, or in the lower half and odd, so
|
||||||
// switch must be "straight" to get i to upper half.
|
// switch must be "straight" to get i to upper half.
|
||||||
switchValues[2 * logN - level - 1][blockStart + switchNum] = false;
|
switchValues[2 * logN - level - 2][blockStart + switchNum] = false;
|
||||||
|
iSwitched = i;
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
switchValues[2 * logN - level - 1][blockStart + switchNum] = true;
|
switchValues[2 * logN - level - 2][blockStart + switchNum] = true;
|
||||||
|
iSwitched = i ^ 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
int iConnectedTo = getOutputIdxInPreviousLayer(logN - level, 2 * (logN - level), i);
|
// Index of outPerm[i] in the output of the previous layer.
|
||||||
int iPairConnectedTo = getOutputIdxInPreviousLayer(logN - level, 2 * (logN - level), i ^ 1);
|
int iConnectedTo = getOutputIdxInPreviousLayer(logN - level, 2 * (logN - level) - 2, iSwitched);
|
||||||
|
int iPairConnectedTo = getOutputIdxInPreviousLayer(logN - level, 2 * (logN - level) - 2, iSwitched ^ 1);
|
||||||
|
|
||||||
upperPermOut[iConnectedTo] = permOut[i];
|
upperPermOut[iConnectedTo] = permOut[i];
|
||||||
lowerPermOut[iPairConnectedTo] = permOut[i ^ 1];
|
lowerPermOut[iPairConnectedTo - Nover2] = permOut[i ^ 1];
|
||||||
|
|
||||||
|
// Index of permOut[i] before the first switch layer.
|
||||||
int j = find(permOut[i], permIn);
|
int j = find(permOut[i], permIn);
|
||||||
int jSwitchNum = j >>> 1;
|
int jSwitchNum = j >>> 1;
|
||||||
int j
|
|
||||||
|
// Index of permOut[i] after the first switch layer
|
||||||
|
int jSwitched;
|
||||||
|
|
||||||
if ((((j & 1) == 0) && (jSwitchNum < numSwitches / 2)) || ((j & 1) == 1) && (jSwitchNum >= numSwitches / 2)) {
|
if ((((j & 1) == 0) && (jSwitchNum < numSwitches / 2)) || ((j & 1) == 1) && (jSwitchNum >= numSwitches / 2)) {
|
||||||
// Even output in the upper half, or odd output in the lower half
|
// Even output in the upper half, or odd output in the lower half
|
||||||
// so switch needs to be "straight" to get j to the upper half
|
// so switch needs to be "straight" to get j to the upper half
|
||||||
switchValues[level][blockStart + switchNum] = false;
|
switchValues[level][blockStart + switchNum] = false;
|
||||||
|
jSwitched = j;
|
||||||
} else {
|
} else {
|
||||||
// Otherwise switch needs to be "crossed" to get j to upper half
|
// Otherwise switch needs to be "crossed" to get j to upper half
|
||||||
switchValues[level][blockStart + switchNum] = true;
|
switchValues[level][blockStart + switchNum] = true;
|
||||||
j ^= 1;
|
jSwitched = j ^ 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
int jConnectedTo = getInputIdxInNextLayer(logN - level, 0, j);
|
int jConnectedTo = getInputIdxInNextLayer(logN - level, 0, jSwitched);
|
||||||
upperPermIn[jConnectedTo] = j;
|
// must be in upper half
|
||||||
|
assert(jConnectedTo < (1 << (logN - 1)));
|
||||||
|
upperPermIn[jConnectedTo] = permOut[i];
|
||||||
|
|
||||||
lowerPermIn[i] = permIn[i + 1];
|
// Connect from j's pair to through the lower half
|
||||||
|
int jPairConnectedTo = getInputIdxInNextLayer(logN - level, 0, jSwitched ^ 1);
|
||||||
|
lowerPermIn[jPairConnectedTo - Nover2] = permIn[j ^ 1];
|
||||||
|
|
||||||
|
int iPairNext = find(permIn[j ^ 1], permOut);
|
||||||
|
i = iPairNext ^ 1;
|
||||||
} while (unmatchedIndices.contains(i));
|
} while (unmatchedIndices.contains(i));
|
||||||
}
|
}
|
||||||
|
|
||||||
setInternalPermutation(upperPermIn, upperPermOut, idx, level + 1);
|
setInternalPermutation(upperPermIn, upperPermOut, 2*idx, level + 1);
|
||||||
setInternalPermutation(lowerPermIn, lowerPermOut, idx, level + 1);
|
setInternalPermutation(lowerPermIn, lowerPermOut, 2*idx + 1, level + 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void setPermutation(int[] permutation) {
|
public void setPermutation(int[] permutation) {
|
||||||
|
|
||||||
// We use the recursive algorithm attributed to ? in ?'s paper
|
|
||||||
int level = 0;
|
int level = 0;
|
||||||
|
int[] identity = new int[getNumInputs()];
|
||||||
|
for (int i = 0; i < identity.length; ++i)
|
||||||
|
identity[i] = i;
|
||||||
|
|
||||||
|
setInternalPermutation(identity, permutation, 0, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public <T> void permuteLayer(T[] values, int layer) {
|
public boolean isCrossed(int layer, int switchIdx) {
|
||||||
|
return switchValues[layer][switchIdx];
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public <T> void permuteLayer(ArrayList<T> values, int layer) {
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,8 +18,26 @@ public interface PermutationNetwork {
|
||||||
*/
|
*/
|
||||||
int getNumLayers();
|
int getNumLayers();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Return the index of the switch output in the previous layer that connects to a specified input in the current layer.
|
||||||
|
* Each consecutive pair of outputs is from the same switch, and each consecutive pair of inputs goes to the same switch.
|
||||||
|
* (that is, switch j has inputs/outputs 2j and 2j+1)
|
||||||
|
*
|
||||||
|
* @param layer current layer index. Must be in the range [0, numLayers) (layer 0's previous layer consists of the inputs themselves)
|
||||||
|
* @param inputIdx the input Idx for the current layer. Must be in the range [0, numInputs)
|
||||||
|
*
|
||||||
|
* @return the requested index
|
||||||
|
*/
|
||||||
int getOutputIdxInPreviousLayer(int layer, int inputIdx);
|
int getOutputIdxInPreviousLayer(int layer, int inputIdx);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The inverse of {@link }#getOutputIdxInPreviousLayer(int,int)}.
|
||||||
|
* @param layer
|
||||||
|
* @param inputIdx
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
int getInputIdxInNextLayer(int layer, int inputIdx);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Initialize switches to generate a specific permutation.
|
* Initialize switches to generate a specific permutation.
|
||||||
* @param permutation
|
* @param permutation
|
||||||
|
@ -28,12 +46,10 @@ public interface PermutationNetwork {
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Apply a single layer's permutation (as implied by {@link #setPermutation(int[])})
|
* Returns true iff the switch is crossed
|
||||||
* @param values values output by previous layer
|
* @param layer
|
||||||
* @param layer layer to apply (layer is between 1 and numLayers; layer 0 is the input layer)
|
* @return
|
||||||
* @param <T> type of values.
|
|
||||||
*/
|
*/
|
||||||
<T> void permuteLayer(T[] values, int layer);
|
boolean isCrossed(int layer, int switchIdx);
|
||||||
|
|
||||||
<T> void permuteLayer(ArrayList<T> values, int layer);
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,7 +7,7 @@ import java.util.Random;
|
||||||
/**
|
/**
|
||||||
* Created by Tzlil on 12/17/2015.
|
* Created by Tzlil on 12/17/2015.
|
||||||
* container for random permutation
|
* container for random permutation
|
||||||
* the permutation is sated in constructor and can't be change
|
* the permutation is set in the constructor and can't be changed
|
||||||
*/
|
*/
|
||||||
public class RandomPermutation {
|
public class RandomPermutation {
|
||||||
public final int[] permutation;
|
public final int[] permutation;
|
||||||
|
@ -43,4 +43,19 @@ public class RandomPermutation {
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* randomly shuffle an array of elements in-place (using Fisher-Yates)
|
||||||
|
* @param perm array of elements to shuffle.
|
||||||
|
* @param random
|
||||||
|
*/
|
||||||
|
public static <T> void permute(T[] perm, Random random){
|
||||||
|
for (int i = perm.length - 1; i > 0; --i) {
|
||||||
|
int j = random.nextInt(i+1);
|
||||||
|
T tmp = perm[i];
|
||||||
|
perm[i] = perm[j];
|
||||||
|
perm[j] = tmp;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,45 @@
|
||||||
|
package meerkat.mixer.mixing;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generic utilities
|
||||||
|
*/
|
||||||
|
public class Util {
|
||||||
|
/*
|
||||||
|
* Apply a single layer's switch permutations in place (as implied by {@link PermutationNetwork#setPermutation(int[])})
|
||||||
|
* @param values values output by previous layxer
|
||||||
|
* @param layer layer to apply (layer is between 1 and numLayers; layer 0 is the input layer)
|
||||||
|
* @param <T> type of values.
|
||||||
|
*/
|
||||||
|
public static <T> void applyLayerSwitches(PermutationNetwork net, T[] values, int layer) {
|
||||||
|
int numSwitches = net.getNumInputs() >>> 1;
|
||||||
|
for (int i = 0; i < numSwitches; ++i) {
|
||||||
|
if (net.isCrossed(layer, i)) {
|
||||||
|
T tmp = values[i * 2];
|
||||||
|
values[i * 2] = values[i * 2 + 1];
|
||||||
|
values[i * 2 + 1] = tmp;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public static <T> void convertOutputPermToInputPerm(PermutationNetwork net, T[] oldValues, T[] newValues, int layer) {
|
||||||
|
int numInputs = net.getNumInputs();
|
||||||
|
assert(numInputs == oldValues.length);
|
||||||
|
assert(numInputs == newValues.length);
|
||||||
|
|
||||||
|
for (int i = 0; i < numInputs; ++i) {
|
||||||
|
newValues[i] = oldValues[net.getOutputIdxInPreviousLayer(layer, i)];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public static <T> void permute(PermutationNetwork net, T[] oldValues, T[] newValues) {
|
||||||
|
for (int layer = 0; layer < net.getNumLayers(); ++layer) {
|
||||||
|
convertOutputPermToInputPerm(net, oldValues, newValues, layer);
|
||||||
|
applyLayerSwitches(net, newValues, layer);
|
||||||
|
System.arraycopy(newValues, 0, oldValues, 0, oldValues.length);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -45,7 +45,7 @@ public class MixingTest extends ECParamTestBase {
|
||||||
mixer = new Mixer(prover, enc);
|
mixer = new Mixer(prover, enc);
|
||||||
|
|
||||||
// generate n
|
// generate n
|
||||||
int logN = 8; // + random.nextInt(8)
|
int logN = 9; // + random.nextInt(8)
|
||||||
layers = 2*logN - 1;
|
layers = 2*logN - 1;
|
||||||
n = 1 << logN;
|
n = 1 << logN;
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,15 +1,20 @@
|
||||||
package meerkat.mixer.mixing;
|
package meerkat.mixer.mixing;
|
||||||
|
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.TreeSet;
|
||||||
|
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.Assert.*;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Tests for Benes Network topology
|
* Tests for Benes Network topology
|
||||||
*/
|
*/
|
||||||
public class BenesNetworkTest {
|
public class BenesNetworkTest extends PermutationNetworkTest {
|
||||||
public static class Permutation {
|
final static int logN = 3;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected PermutationNetwork getNewNetwork() {
|
||||||
|
return new BenesNetwork(logN);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static isAPerm
|
|
||||||
|
|
||||||
}
|
}
|
|
@ -0,0 +1,109 @@
|
||||||
|
package meerkat.mixer.mixing;
|
||||||
|
|
||||||
|
import org.junit.Before;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
import java.util.Random;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.TreeSet;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertArrayEquals;
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generically Test a permutation network
|
||||||
|
*/
|
||||||
|
abstract public class PermutationNetworkTest {
|
||||||
|
|
||||||
|
public static final int NUM_REPS = 10;
|
||||||
|
Random rand;
|
||||||
|
|
||||||
|
abstract protected PermutationNetwork getNewNetwork();
|
||||||
|
|
||||||
|
PermutationNetwork network;
|
||||||
|
|
||||||
|
@Before
|
||||||
|
public void setup() {
|
||||||
|
network = getNewNetwork();
|
||||||
|
rand = new Random(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Set<Integer> getSequenceSet(int N) {
|
||||||
|
Set<Integer> set = new TreeSet<>();
|
||||||
|
|
||||||
|
for (int i = 0; i < N; ++i) {
|
||||||
|
set.add(i);
|
||||||
|
}
|
||||||
|
return set;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Integer[] getSequenceArray(int N) {
|
||||||
|
Integer[] arr = new Integer[N];
|
||||||
|
|
||||||
|
for (int i = 0; i < N; ++i) {
|
||||||
|
arr[i] = i;
|
||||||
|
}
|
||||||
|
return arr;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check if a given network is actually a permutation network (i.e., always
|
||||||
|
* implies a permutation regardless of 2x2 switch settings).
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void isAlwaysAPermutation() {
|
||||||
|
int numLayers = network.getNumLayers();
|
||||||
|
int N = network.getNumInputs();
|
||||||
|
|
||||||
|
for (int layer = 1; layer < numLayers; ++layer) {
|
||||||
|
Set<Integer> unusedInputs = getSequenceSet(N);
|
||||||
|
for (int i = 0; i < N; ++i) {
|
||||||
|
unusedInputs.remove(network.getOutputIdxInPreviousLayer(layer, i));
|
||||||
|
}
|
||||||
|
assertTrue("Not a permutation! Didn't use: " + unusedInputs, unusedInputs.isEmpty());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void forwardEqualsBackwards() {
|
||||||
|
int numLayers = network.getNumLayers();
|
||||||
|
int N = network.getNumInputs();
|
||||||
|
|
||||||
|
for (int layer = 1; layer < numLayers; ++layer) {
|
||||||
|
for (int i = 0; i < N; ++i) {
|
||||||
|
int j = network.getOutputIdxInPreviousLayer(layer, i);
|
||||||
|
assertEquals(String.format("Input %d in layer %d has problems", i, layer), i,
|
||||||
|
network.getInputIdxInNextLayer(layer - 1, j));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static int[] convert(Integer[] in) {
|
||||||
|
int[] out = new int[in.length];
|
||||||
|
for (int i = 0; i < in.length; ++i)
|
||||||
|
out[i] = in[i];
|
||||||
|
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testRandomPermutations() throws Exception {
|
||||||
|
for (int rep = 0; rep < NUM_REPS; ++rep) {
|
||||||
|
Integer[] target = getSequenceArray(network.getNumInputs());
|
||||||
|
RandomPermutation.permute(target, rand);
|
||||||
|
|
||||||
|
network.setPermutation(convert(target));
|
||||||
|
|
||||||
|
Integer[] id = getSequenceArray(network.getNumInputs());
|
||||||
|
|
||||||
|
Integer[] out = new Integer[target.length];
|
||||||
|
|
||||||
|
Util.permute(network, id, out);
|
||||||
|
|
||||||
|
assertArrayEquals("Permutation mismatch: " + target + " != " + out, target, out);
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue