Working on Benes network/testing
parent
abf4cc5e54
commit
37d1857f9c
|
@ -39,25 +39,22 @@ public class BenesNetwork implements PermutationNetwork
|
|||
* @return the requested index
|
||||
*/
|
||||
public static int getOutputIdxInPreviousLayer(int logN, int layer, int inputIdx) {
|
||||
assert (layer > 0) && (layer < 2*logN - 1);
|
||||
assert (layer >= 0) && (layer < 2*logN - 1);
|
||||
assert (inputIdx >= 0) && (inputIdx < 1 << inputIdx);
|
||||
|
||||
if ((inputIdx & 1) == 0) {
|
||||
// Even inputs are connected straight "across" everywhere
|
||||
if ((layer == 0) || (inputIdx & 1) == 0) {
|
||||
// layer 0 inputs and all even inputs
|
||||
// are connected straight "across" everywhere
|
||||
return inputIdx;
|
||||
}
|
||||
|
||||
// --- Odd inputs are connected depending on layer ---
|
||||
|
||||
// In middle layer everything goes across
|
||||
if (layer == logN)
|
||||
return inputIdx;
|
||||
|
||||
int crossBit;
|
||||
if (layer < logN)
|
||||
crossBit = logN - layer;
|
||||
else
|
||||
crossBit = layer - logN;
|
||||
crossBit = layer - logN + 1;
|
||||
|
||||
return inputIdx ^ (1 << crossBit);
|
||||
}
|
||||
|
@ -73,6 +70,11 @@ public class BenesNetwork implements PermutationNetwork
|
|||
return getOutputIdxInPreviousLayer(logN, layer + 1, outputIdx);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getInputIdxInNextLayer(int layer, int outputIdx) {
|
||||
return getInputIdxInNextLayer(logN, layer, outputIdx);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getNumInputs() {
|
||||
return 1 << logN;
|
||||
|
@ -111,11 +113,13 @@ public class BenesNetwork implements PermutationNetwork
|
|||
* @param level
|
||||
*/
|
||||
private void setInternalPermutation(int[] permIn, int[] permOut, int idx, int level) {
|
||||
// Make sure idx is in the the range [0, 2^{logN-level-1})
|
||||
assert (idx == (idx & ((1 << (logN - level - 1)) - 1)));
|
||||
// Make sure idx is in the the range [0, 2^{level})
|
||||
assert (idx < (1 << level));
|
||||
assert (permOut.length == (1 << (logN - level)));
|
||||
assert (permIn.length == permOut.length);
|
||||
|
||||
final int Nover2 = 1 << (logN - level - 1);
|
||||
|
||||
// Base case
|
||||
if (level == logN - 1) {
|
||||
if (permOut[0] == permIn[0])
|
||||
|
@ -150,65 +154,80 @@ public class BenesNetwork implements PermutationNetwork
|
|||
do {
|
||||
//assert ((i & 1) == 0); // Algorithm guarantees only even indices are unmatched at start of loop
|
||||
unmatchedIndices.remove(i);
|
||||
unmatchedIndices.remove(i ^ 1);
|
||||
|
||||
int switchNum = i >>> 1;
|
||||
|
||||
// index of outPerm[i] after the last switch layer.
|
||||
int iSwitched;
|
||||
|
||||
if (((i & 1) == 0) && (switchNum < numSwitches / 2) || ((i & 1) == 1) && (switchNum >= numSwitches / 2)) {
|
||||
// i is in the upper half and even, or in the lower half and odd, so
|
||||
// switch must be "straight" to get i to upper half.
|
||||
switchValues[2 * logN - level - 1][blockStart + switchNum] = false;
|
||||
switchValues[2 * logN - level - 2][blockStart + switchNum] = false;
|
||||
iSwitched = i;
|
||||
|
||||
} else {
|
||||
switchValues[2 * logN - level - 1][blockStart + switchNum] = true;
|
||||
switchValues[2 * logN - level - 2][blockStart + switchNum] = true;
|
||||
iSwitched = i ^ 1;
|
||||
}
|
||||
|
||||
int iConnectedTo = getOutputIdxInPreviousLayer(logN - level, 2 * (logN - level), i);
|
||||
int iPairConnectedTo = getOutputIdxInPreviousLayer(logN - level, 2 * (logN - level), i ^ 1);
|
||||
// Index of outPerm[i] in the output of the previous layer.
|
||||
int iConnectedTo = getOutputIdxInPreviousLayer(logN - level, 2 * (logN - level) - 2, iSwitched);
|
||||
int iPairConnectedTo = getOutputIdxInPreviousLayer(logN - level, 2 * (logN - level) - 2, iSwitched ^ 1);
|
||||
|
||||
upperPermOut[iConnectedTo] = permOut[i];
|
||||
lowerPermOut[iPairConnectedTo] = permOut[i ^ 1];
|
||||
lowerPermOut[iPairConnectedTo - Nover2] = permOut[i ^ 1];
|
||||
|
||||
// Index of permOut[i] before the first switch layer.
|
||||
int j = find(permOut[i], permIn);
|
||||
int jSwitchNum = j >>> 1;
|
||||
int j
|
||||
|
||||
// Index of permOut[i] after the first switch layer
|
||||
int jSwitched;
|
||||
|
||||
if ((((j & 1) == 0) && (jSwitchNum < numSwitches / 2)) || ((j & 1) == 1) && (jSwitchNum >= numSwitches / 2)) {
|
||||
// Even output in the upper half, or odd output in the lower half
|
||||
// so switch needs to be "straight" to get j to the upper half
|
||||
switchValues[level][blockStart + switchNum] = false;
|
||||
jSwitched = j;
|
||||
} else {
|
||||
// Otherwise switch needs to be "crossed" to get j to upper half
|
||||
switchValues[level][blockStart + switchNum] = true;
|
||||
j ^= 1;
|
||||
jSwitched = j ^ 1;
|
||||
}
|
||||
|
||||
int jConnectedTo = getInputIdxInNextLayer(logN - level, 0, j);
|
||||
upperPermIn[jConnectedTo] = j;
|
||||
int jConnectedTo = getInputIdxInNextLayer(logN - level, 0, jSwitched);
|
||||
// must be in upper half
|
||||
assert(jConnectedTo < (1 << (logN - 1)));
|
||||
upperPermIn[jConnectedTo] = permOut[i];
|
||||
|
||||
lowerPermIn[i] = permIn[i + 1];
|
||||
// Connect from j's pair to through the lower half
|
||||
int jPairConnectedTo = getInputIdxInNextLayer(logN - level, 0, jSwitched ^ 1);
|
||||
lowerPermIn[jPairConnectedTo - Nover2] = permIn[j ^ 1];
|
||||
|
||||
int iPairNext = find(permIn[j ^ 1], permOut);
|
||||
i = iPairNext ^ 1;
|
||||
} while (unmatchedIndices.contains(i));
|
||||
}
|
||||
|
||||
setInternalPermutation(upperPermIn, upperPermOut, idx, level + 1);
|
||||
setInternalPermutation(lowerPermIn, lowerPermOut, idx, level + 1);
|
||||
setInternalPermutation(upperPermIn, upperPermOut, 2*idx, level + 1);
|
||||
setInternalPermutation(lowerPermIn, lowerPermOut, 2*idx + 1, level + 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setPermutation(int[] permutation) {
|
||||
|
||||
// We use the recursive algorithm attributed to ? in ?'s paper
|
||||
int level = 0;
|
||||
int[] identity = new int[getNumInputs()];
|
||||
for (int i = 0; i < identity.length; ++i)
|
||||
identity[i] = i;
|
||||
|
||||
setInternalPermutation(identity, permutation, 0, 0);
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> void permuteLayer(T[] values, int layer) {
|
||||
|
||||
public boolean isCrossed(int layer, int switchIdx) {
|
||||
return switchValues[layer][switchIdx];
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> void permuteLayer(ArrayList<T> values, int layer) {
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,8 +18,26 @@ public interface PermutationNetwork {
|
|||
*/
|
||||
int getNumLayers();
|
||||
|
||||
/**
|
||||
* Return the index of the switch output in the previous layer that connects to a specified input in the current layer.
|
||||
* Each consecutive pair of outputs is from the same switch, and each consecutive pair of inputs goes to the same switch.
|
||||
* (that is, switch j has inputs/outputs 2j and 2j+1)
|
||||
*
|
||||
* @param layer current layer index. Must be in the range [0, numLayers) (layer 0's previous layer consists of the inputs themselves)
|
||||
* @param inputIdx the input Idx for the current layer. Must be in the range [0, numInputs)
|
||||
*
|
||||
* @return the requested index
|
||||
*/
|
||||
int getOutputIdxInPreviousLayer(int layer, int inputIdx);
|
||||
|
||||
/**
|
||||
* The inverse of {@link }#getOutputIdxInPreviousLayer(int,int)}.
|
||||
* @param layer
|
||||
* @param inputIdx
|
||||
* @return
|
||||
*/
|
||||
int getInputIdxInNextLayer(int layer, int inputIdx);
|
||||
|
||||
/**
|
||||
* Initialize switches to generate a specific permutation.
|
||||
* @param permutation
|
||||
|
@ -28,12 +46,10 @@ public interface PermutationNetwork {
|
|||
|
||||
|
||||
/**
|
||||
* Apply a single layer's permutation (as implied by {@link #setPermutation(int[])})
|
||||
* @param values values output by previous layer
|
||||
* @param layer layer to apply (layer is between 1 and numLayers; layer 0 is the input layer)
|
||||
* @param <T> type of values.
|
||||
* Returns true iff the switch is crossed
|
||||
* @param layer
|
||||
* @return
|
||||
*/
|
||||
<T> void permuteLayer(T[] values, int layer);
|
||||
boolean isCrossed(int layer, int switchIdx);
|
||||
|
||||
<T> void permuteLayer(ArrayList<T> values, int layer);
|
||||
}
|
||||
|
|
|
@ -7,7 +7,7 @@ import java.util.Random;
|
|||
/**
|
||||
* Created by Tzlil on 12/17/2015.
|
||||
* container for random permutation
|
||||
* the permutation is sated in constructor and can't be change
|
||||
* the permutation is set in the constructor and can't be changed
|
||||
*/
|
||||
public class RandomPermutation {
|
||||
public final int[] permutation;
|
||||
|
@ -43,4 +43,19 @@ public class RandomPermutation {
|
|||
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* randomly shuffle an array of elements in-place (using Fisher-Yates)
|
||||
* @param perm array of elements to shuffle.
|
||||
* @param random
|
||||
*/
|
||||
public static <T> void permute(T[] perm, Random random){
|
||||
for (int i = perm.length - 1; i > 0; --i) {
|
||||
int j = random.nextInt(i+1);
|
||||
T tmp = perm[i];
|
||||
perm[i] = perm[j];
|
||||
perm[j] = tmp;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,45 @@
|
|||
package meerkat.mixer.mixing;
|
||||
|
||||
import java.util.ArrayList;
|
||||
|
||||
/**
|
||||
* Generic utilities
|
||||
*/
|
||||
public class Util {
|
||||
/*
|
||||
* Apply a single layer's switch permutations in place (as implied by {@link PermutationNetwork#setPermutation(int[])})
|
||||
* @param values values output by previous layxer
|
||||
* @param layer layer to apply (layer is between 1 and numLayers; layer 0 is the input layer)
|
||||
* @param <T> type of values.
|
||||
*/
|
||||
public static <T> void applyLayerSwitches(PermutationNetwork net, T[] values, int layer) {
|
||||
int numSwitches = net.getNumInputs() >>> 1;
|
||||
for (int i = 0; i < numSwitches; ++i) {
|
||||
if (net.isCrossed(layer, i)) {
|
||||
T tmp = values[i * 2];
|
||||
values[i * 2] = values[i * 2 + 1];
|
||||
values[i * 2 + 1] = tmp;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
public static <T> void convertOutputPermToInputPerm(PermutationNetwork net, T[] oldValues, T[] newValues, int layer) {
|
||||
int numInputs = net.getNumInputs();
|
||||
assert(numInputs == oldValues.length);
|
||||
assert(numInputs == newValues.length);
|
||||
|
||||
for (int i = 0; i < numInputs; ++i) {
|
||||
newValues[i] = oldValues[net.getOutputIdxInPreviousLayer(layer, i)];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
public static <T> void permute(PermutationNetwork net, T[] oldValues, T[] newValues) {
|
||||
for (int layer = 0; layer < net.getNumLayers(); ++layer) {
|
||||
convertOutputPermToInputPerm(net, oldValues, newValues, layer);
|
||||
applyLayerSwitches(net, newValues, layer);
|
||||
System.arraycopy(newValues, 0, oldValues, 0, oldValues.length);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -45,7 +45,7 @@ public class MixingTest extends ECParamTestBase {
|
|||
mixer = new Mixer(prover, enc);
|
||||
|
||||
// generate n
|
||||
int logN = 8; // + random.nextInt(8)
|
||||
int logN = 9; // + random.nextInt(8)
|
||||
layers = 2*logN - 1;
|
||||
n = 1 << logN;
|
||||
}
|
||||
|
|
|
@ -1,15 +1,20 @@
|
|||
package meerkat.mixer.mixing;
|
||||
|
||||
import org.junit.Test;
|
||||
|
||||
import java.util.Set;
|
||||
import java.util.TreeSet;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
/**
|
||||
* Tests for Benes Network topology
|
||||
*/
|
||||
public class BenesNetworkTest {
|
||||
public static class Permutation {
|
||||
public class BenesNetworkTest extends PermutationNetworkTest {
|
||||
final static int logN = 3;
|
||||
|
||||
@Override
|
||||
protected PermutationNetwork getNewNetwork() {
|
||||
return new BenesNetwork(logN);
|
||||
}
|
||||
|
||||
public static isAPerm
|
||||
|
||||
}
|
|
@ -0,0 +1,109 @@
|
|||
package meerkat.mixer.mixing;
|
||||
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.util.Random;
|
||||
import java.util.Set;
|
||||
import java.util.TreeSet;
|
||||
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
/**
|
||||
* Generically Test a permutation network
|
||||
*/
|
||||
abstract public class PermutationNetworkTest {
|
||||
|
||||
public static final int NUM_REPS = 10;
|
||||
Random rand;
|
||||
|
||||
abstract protected PermutationNetwork getNewNetwork();
|
||||
|
||||
PermutationNetwork network;
|
||||
|
||||
@Before
|
||||
public void setup() {
|
||||
network = getNewNetwork();
|
||||
rand = new Random(1);
|
||||
}
|
||||
|
||||
public static Set<Integer> getSequenceSet(int N) {
|
||||
Set<Integer> set = new TreeSet<>();
|
||||
|
||||
for (int i = 0; i < N; ++i) {
|
||||
set.add(i);
|
||||
}
|
||||
return set;
|
||||
}
|
||||
|
||||
public static Integer[] getSequenceArray(int N) {
|
||||
Integer[] arr = new Integer[N];
|
||||
|
||||
for (int i = 0; i < N; ++i) {
|
||||
arr[i] = i;
|
||||
}
|
||||
return arr;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Check if a given network is actually a permutation network (i.e., always
|
||||
* implies a permutation regardless of 2x2 switch settings).
|
||||
*/
|
||||
@Test
|
||||
public void isAlwaysAPermutation() {
|
||||
int numLayers = network.getNumLayers();
|
||||
int N = network.getNumInputs();
|
||||
|
||||
for (int layer = 1; layer < numLayers; ++layer) {
|
||||
Set<Integer> unusedInputs = getSequenceSet(N);
|
||||
for (int i = 0; i < N; ++i) {
|
||||
unusedInputs.remove(network.getOutputIdxInPreviousLayer(layer, i));
|
||||
}
|
||||
assertTrue("Not a permutation! Didn't use: " + unusedInputs, unusedInputs.isEmpty());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void forwardEqualsBackwards() {
|
||||
int numLayers = network.getNumLayers();
|
||||
int N = network.getNumInputs();
|
||||
|
||||
for (int layer = 1; layer < numLayers; ++layer) {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
int j = network.getOutputIdxInPreviousLayer(layer, i);
|
||||
assertEquals(String.format("Input %d in layer %d has problems", i, layer), i,
|
||||
network.getInputIdxInNextLayer(layer - 1, j));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public static int[] convert(Integer[] in) {
|
||||
int[] out = new int[in.length];
|
||||
for (int i = 0; i < in.length; ++i)
|
||||
out[i] = in[i];
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRandomPermutations() throws Exception {
|
||||
for (int rep = 0; rep < NUM_REPS; ++rep) {
|
||||
Integer[] target = getSequenceArray(network.getNumInputs());
|
||||
RandomPermutation.permute(target, rand);
|
||||
|
||||
network.setPermutation(convert(target));
|
||||
|
||||
Integer[] id = getSequenceArray(network.getNumInputs());
|
||||
|
||||
Integer[] out = new Integer[target.length];
|
||||
|
||||
Util.permute(network, id, out);
|
||||
|
||||
assertArrayEquals("Permutation mismatch: " + target + " != " + out, target, out);
|
||||
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue