Working on Benes network/testing

mixer
Tal Moran 2017-01-18 21:54:42 +02:00
parent abf4cc5e54
commit 37d1857f9c
7 changed files with 252 additions and 43 deletions

View File

@ -39,25 +39,22 @@ public class BenesNetwork implements PermutationNetwork
* @return the requested index
*/
public static int getOutputIdxInPreviousLayer(int logN, int layer, int inputIdx) {
assert (layer > 0) && (layer < 2*logN - 1);
assert (layer >= 0) && (layer < 2*logN - 1);
assert (inputIdx >= 0) && (inputIdx < 1 << inputIdx);
if ((inputIdx & 1) == 0) {
// Even inputs are connected straight "across" everywhere
if ((layer == 0) || (inputIdx & 1) == 0) {
// layer 0 inputs and all even inputs
// are connected straight "across" everywhere
return inputIdx;
}
// --- Odd inputs are connected depending on layer ---
// In middle layer everything goes across
if (layer == logN)
return inputIdx;
int crossBit;
if (layer < logN)
crossBit = logN - layer;
else
crossBit = layer - logN;
crossBit = layer - logN + 1;
return inputIdx ^ (1 << crossBit);
}
@ -73,6 +70,11 @@ public class BenesNetwork implements PermutationNetwork
return getOutputIdxInPreviousLayer(logN, layer + 1, outputIdx);
}
@Override
public int getInputIdxInNextLayer(int layer, int outputIdx) {
return getInputIdxInNextLayer(logN, layer, outputIdx);
}
@Override
public int getNumInputs() {
return 1 << logN;
@ -111,11 +113,13 @@ public class BenesNetwork implements PermutationNetwork
* @param level
*/
private void setInternalPermutation(int[] permIn, int[] permOut, int idx, int level) {
// Make sure idx is in the the range [0, 2^{logN-level-1})
assert (idx == (idx & ((1 << (logN - level - 1)) - 1)));
// Make sure idx is in the the range [0, 2^{level})
assert (idx < (1 << level));
assert (permOut.length == (1 << (logN - level)));
assert (permIn.length == permOut.length);
final int Nover2 = 1 << (logN - level - 1);
// Base case
if (level == logN - 1) {
if (permOut[0] == permIn[0])
@ -150,65 +154,80 @@ public class BenesNetwork implements PermutationNetwork
do {
//assert ((i & 1) == 0); // Algorithm guarantees only even indices are unmatched at start of loop
unmatchedIndices.remove(i);
unmatchedIndices.remove(i ^ 1);
int switchNum = i >>> 1;
// index of outPerm[i] after the last switch layer.
int iSwitched;
if (((i & 1) == 0) && (switchNum < numSwitches / 2) || ((i & 1) == 1) && (switchNum >= numSwitches / 2)) {
// i is in the upper half and even, or in the lower half and odd, so
// switch must be "straight" to get i to upper half.
switchValues[2 * logN - level - 1][blockStart + switchNum] = false;
switchValues[2 * logN - level - 2][blockStart + switchNum] = false;
iSwitched = i;
} else {
switchValues[2 * logN - level - 1][blockStart + switchNum] = true;
switchValues[2 * logN - level - 2][blockStart + switchNum] = true;
iSwitched = i ^ 1;
}
int iConnectedTo = getOutputIdxInPreviousLayer(logN - level, 2 * (logN - level), i);
int iPairConnectedTo = getOutputIdxInPreviousLayer(logN - level, 2 * (logN - level), i ^ 1);
// Index of outPerm[i] in the output of the previous layer.
int iConnectedTo = getOutputIdxInPreviousLayer(logN - level, 2 * (logN - level) - 2, iSwitched);
int iPairConnectedTo = getOutputIdxInPreviousLayer(logN - level, 2 * (logN - level) - 2, iSwitched ^ 1);
upperPermOut[iConnectedTo] = permOut[i];
lowerPermOut[iPairConnectedTo] = permOut[i ^ 1];
lowerPermOut[iPairConnectedTo - Nover2] = permOut[i ^ 1];
// Index of permOut[i] before the first switch layer.
int j = find(permOut[i], permIn);
int jSwitchNum = j >>> 1;
int j
// Index of permOut[i] after the first switch layer
int jSwitched;
if ((((j & 1) == 0) && (jSwitchNum < numSwitches / 2)) || ((j & 1) == 1) && (jSwitchNum >= numSwitches / 2)) {
// Even output in the upper half, or odd output in the lower half
// so switch needs to be "straight" to get j to the upper half
switchValues[level][blockStart + switchNum] = false;
jSwitched = j;
} else {
// Otherwise switch needs to be "crossed" to get j to upper half
switchValues[level][blockStart + switchNum] = true;
j ^= 1;
jSwitched = j ^ 1;
}
int jConnectedTo = getInputIdxInNextLayer(logN - level, 0, j);
upperPermIn[jConnectedTo] = j;
int jConnectedTo = getInputIdxInNextLayer(logN - level, 0, jSwitched);
// must be in upper half
assert(jConnectedTo < (1 << (logN - 1)));
upperPermIn[jConnectedTo] = permOut[i];
lowerPermIn[i] = permIn[i + 1];
// Connect from j's pair to through the lower half
int jPairConnectedTo = getInputIdxInNextLayer(logN - level, 0, jSwitched ^ 1);
lowerPermIn[jPairConnectedTo - Nover2] = permIn[j ^ 1];
int iPairNext = find(permIn[j ^ 1], permOut);
i = iPairNext ^ 1;
} while (unmatchedIndices.contains(i));
}
setInternalPermutation(upperPermIn, upperPermOut, idx, level + 1);
setInternalPermutation(lowerPermIn, lowerPermOut, idx, level + 1);
setInternalPermutation(upperPermIn, upperPermOut, 2*idx, level + 1);
setInternalPermutation(lowerPermIn, lowerPermOut, 2*idx + 1, level + 1);
}
@Override
public void setPermutation(int[] permutation) {
// We use the recursive algorithm attributed to ? in ?'s paper
int level = 0;
int[] identity = new int[getNumInputs()];
for (int i = 0; i < identity.length; ++i)
identity[i] = i;
setInternalPermutation(identity, permutation, 0, 0);
}
@Override
public <T> void permuteLayer(T[] values, int layer) {
public boolean isCrossed(int layer, int switchIdx) {
return switchValues[layer][switchIdx];
}
@Override
public <T> void permuteLayer(ArrayList<T> values, int layer) {
}
}

View File

@ -18,8 +18,26 @@ public interface PermutationNetwork {
*/
int getNumLayers();
/**
* Return the index of the switch output in the previous layer that connects to a specified input in the current layer.
* Each consecutive pair of outputs is from the same switch, and each consecutive pair of inputs goes to the same switch.
* (that is, switch j has inputs/outputs 2j and 2j+1)
*
* @param layer current layer index. Must be in the range [0, numLayers) (layer 0's previous layer consists of the inputs themselves)
* @param inputIdx the input Idx for the current layer. Must be in the range [0, numInputs)
*
* @return the requested index
*/
int getOutputIdxInPreviousLayer(int layer, int inputIdx);
/**
* The inverse of {@link }#getOutputIdxInPreviousLayer(int,int)}.
* @param layer
* @param inputIdx
* @return
*/
int getInputIdxInNextLayer(int layer, int inputIdx);
/**
* Initialize switches to generate a specific permutation.
* @param permutation
@ -28,12 +46,10 @@ public interface PermutationNetwork {
/**
* Apply a single layer's permutation (as implied by {@link #setPermutation(int[])})
* @param values values output by previous layer
* @param layer layer to apply (layer is between 1 and numLayers; layer 0 is the input layer)
* @param <T> type of values.
* Returns true iff the switch is crossed
* @param layer
* @return
*/
<T> void permuteLayer(T[] values, int layer);
boolean isCrossed(int layer, int switchIdx);
<T> void permuteLayer(ArrayList<T> values, int layer);
}

View File

@ -7,7 +7,7 @@ import java.util.Random;
/**
* Created by Tzlil on 12/17/2015.
* container for random permutation
* the permutation is sated in constructor and can't be change
* the permutation is set in the constructor and can't be changed
*/
public class RandomPermutation {
public final int[] permutation;
@ -43,4 +43,19 @@ public class RandomPermutation {
return result;
}
/**
* randomly shuffle an array of elements in-place (using Fisher-Yates)
* @param perm array of elements to shuffle.
* @param random
*/
public static <T> void permute(T[] perm, Random random){
for (int i = perm.length - 1; i > 0; --i) {
int j = random.nextInt(i+1);
T tmp = perm[i];
perm[i] = perm[j];
perm[j] = tmp;
}
}
}

View File

@ -0,0 +1,45 @@
package meerkat.mixer.mixing;
import java.util.ArrayList;
/**
* Generic utilities
*/
public class Util {
/*
* Apply a single layer's switch permutations in place (as implied by {@link PermutationNetwork#setPermutation(int[])})
* @param values values output by previous layxer
* @param layer layer to apply (layer is between 1 and numLayers; layer 0 is the input layer)
* @param <T> type of values.
*/
public static <T> void applyLayerSwitches(PermutationNetwork net, T[] values, int layer) {
int numSwitches = net.getNumInputs() >>> 1;
for (int i = 0; i < numSwitches; ++i) {
if (net.isCrossed(layer, i)) {
T tmp = values[i * 2];
values[i * 2] = values[i * 2 + 1];
values[i * 2 + 1] = tmp;
}
}
}
public static <T> void convertOutputPermToInputPerm(PermutationNetwork net, T[] oldValues, T[] newValues, int layer) {
int numInputs = net.getNumInputs();
assert(numInputs == oldValues.length);
assert(numInputs == newValues.length);
for (int i = 0; i < numInputs; ++i) {
newValues[i] = oldValues[net.getOutputIdxInPreviousLayer(layer, i)];
}
}
public static <T> void permute(PermutationNetwork net, T[] oldValues, T[] newValues) {
for (int layer = 0; layer < net.getNumLayers(); ++layer) {
convertOutputPermToInputPerm(net, oldValues, newValues, layer);
applyLayerSwitches(net, newValues, layer);
System.arraycopy(newValues, 0, oldValues, 0, oldValues.length);
}
}
}

View File

@ -45,7 +45,7 @@ public class MixingTest extends ECParamTestBase {
mixer = new Mixer(prover, enc);
// generate n
int logN = 8; // + random.nextInt(8)
int logN = 9; // + random.nextInt(8)
layers = 2*logN - 1;
n = 1 << logN;
}

View File

@ -1,15 +1,20 @@
package meerkat.mixer.mixing;
import org.junit.Test;
import java.util.Set;
import java.util.TreeSet;
import static org.junit.Assert.*;
/**
* Tests for Benes Network topology
*/
public class BenesNetworkTest {
public static class Permutation {
public class BenesNetworkTest extends PermutationNetworkTest {
final static int logN = 3;
@Override
protected PermutationNetwork getNewNetwork() {
return new BenesNetwork(logN);
}
public static isAPerm
}

View File

@ -0,0 +1,109 @@
package meerkat.mixer.mixing;
import org.junit.Before;
import org.junit.Test;
import java.util.Random;
import java.util.Set;
import java.util.TreeSet;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
/**
* Generically Test a permutation network
*/
abstract public class PermutationNetworkTest {
public static final int NUM_REPS = 10;
Random rand;
abstract protected PermutationNetwork getNewNetwork();
PermutationNetwork network;
@Before
public void setup() {
network = getNewNetwork();
rand = new Random(1);
}
public static Set<Integer> getSequenceSet(int N) {
Set<Integer> set = new TreeSet<>();
for (int i = 0; i < N; ++i) {
set.add(i);
}
return set;
}
public static Integer[] getSequenceArray(int N) {
Integer[] arr = new Integer[N];
for (int i = 0; i < N; ++i) {
arr[i] = i;
}
return arr;
}
/**
* Check if a given network is actually a permutation network (i.e., always
* implies a permutation regardless of 2x2 switch settings).
*/
@Test
public void isAlwaysAPermutation() {
int numLayers = network.getNumLayers();
int N = network.getNumInputs();
for (int layer = 1; layer < numLayers; ++layer) {
Set<Integer> unusedInputs = getSequenceSet(N);
for (int i = 0; i < N; ++i) {
unusedInputs.remove(network.getOutputIdxInPreviousLayer(layer, i));
}
assertTrue("Not a permutation! Didn't use: " + unusedInputs, unusedInputs.isEmpty());
}
}
@Test
public void forwardEqualsBackwards() {
int numLayers = network.getNumLayers();
int N = network.getNumInputs();
for (int layer = 1; layer < numLayers; ++layer) {
for (int i = 0; i < N; ++i) {
int j = network.getOutputIdxInPreviousLayer(layer, i);
assertEquals(String.format("Input %d in layer %d has problems", i, layer), i,
network.getInputIdxInNextLayer(layer - 1, j));
}
}
}
public static int[] convert(Integer[] in) {
int[] out = new int[in.length];
for (int i = 0; i < in.length; ++i)
out[i] = in[i];
return out;
}
@Test
public void testRandomPermutations() throws Exception {
for (int rep = 0; rep < NUM_REPS; ++rep) {
Integer[] target = getSequenceArray(network.getNumInputs());
RandomPermutation.permute(target, rand);
network.setPermutation(convert(target));
Integer[] id = getSequenceArray(network.getNumInputs());
Integer[] out = new Integer[target.length];
Util.permute(network, id, out);
assertArrayEquals("Permutation mismatch: " + target + " != " + out, target, out);
}
}
}