Skip to content

Commit 39f89dc

Browse files
committed
embryo majority layer
1 parent f247891 commit 39f89dc

File tree

3 files changed

+157
-0
lines changed

3 files changed

+157
-0
lines changed

neurallogic/hard_majority.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import jax
2+
3+
4+
5+
def majority_index(input_size: int) -> int:
6+
return (input_size - 1) // 2
7+
8+
9+
def soft_majority(x: jax.numpy.array) -> float:
10+
index = majority_index(x.shape[-1])
11+
sorted_x = jax.numpy.sort(x, axis=-1)
12+
return jax.numpy.take(sorted_x, index, axis=-1)
13+
14+
15+
def hard_majority(x: jax.numpy.array) -> bool:
16+
threshold = x.shape[-1] - majority_index(x.shape[-1])
17+
return jax.numpy.sum(x, axis=-1) >= threshold
18+
19+
20+
soft_majority_layer = jax.vmap(soft_majority, in_axes=0)
21+
22+
hard_majority_layer = jax.vmap(hard_majority, in_axes=0)

neurallogic/harden.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,27 @@ def harden(x: float):
2020
return harden_float(x)
2121

2222

23+
@dispatch
24+
def harden(x: bool):
25+
return x
26+
27+
2328
@dispatch
2429
def harden(x: list):
2530
return map_at_elements.map_at_elements(x, harden_float)
2631

2732

2833
@dispatch
2934
def harden(x: numpy.ndarray):
35+
if x.ndim == 0:
36+
return harden(x.item())
3037
return harden_array(x)
3138

3239

3340
@dispatch
3441
def harden(x: jax.numpy.ndarray):
42+
if x.ndim == 0:
43+
return harden(x.item())
3544
return harden_array(x)
3645

3746

tests/test_hard_majority.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
import numpy
2+
import jax
3+
4+
from neurallogic import hard_majority, harden
5+
6+
7+
def test_majority_index():
8+
assert hard_majority.majority_index(1) == 0
9+
assert hard_majority.majority_index(2) == 0
10+
assert hard_majority.majority_index(3) == 1
11+
assert hard_majority.majority_index(4) == 1
12+
assert hard_majority.majority_index(5) == 2
13+
assert hard_majority.majority_index(6) == 2
14+
assert hard_majority.majority_index(7) == 3
15+
assert hard_majority.majority_index(8) == 3
16+
assert hard_majority.majority_index(9) == 4
17+
assert hard_majority.majority_index(10) == 4
18+
assert hard_majority.majority_index(11) == 5
19+
assert hard_majority.majority_index(12) == 5
20+
21+
22+
def test_soft_majority():
23+
assert hard_majority.soft_majority(numpy.array([1.0])) == 1.0
24+
assert hard_majority.soft_majority(numpy.array([2.0, 1.0])) == 1.0
25+
assert hard_majority.soft_majority(numpy.array([1.0, 3.0, 2.0])) == 2.0
26+
assert hard_majority.soft_majority(
27+
numpy.array([2.0, 1.0, 4.0, 3.0])) == 2.0
28+
assert hard_majority.soft_majority(
29+
numpy.array([1.0, 2.0, 3.0, 4.0, 5.0])) == 3.0
30+
assert hard_majority.soft_majority(
31+
numpy.array([6.0, 3.0, 2.0, 4.0, 5.0, 1.0])) == 3.0
32+
assert hard_majority.soft_majority(numpy.array(
33+
[7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0])) == 4.0
34+
assert hard_majority.soft_majority(numpy.array(
35+
[2.0, 1.0, 4.0, 3.0, 6.0, 5.0, 8.0, 7.0])) == 4.0
36+
assert hard_majority.soft_majority(numpy.array(
37+
[1.0, 2.0, 3.0, 5.0, 4.0, 6.0, 7.0, 9.0, 8.0])) == 5.0
38+
assert hard_majority.soft_majority(numpy.array(
39+
[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])) == 5.0
40+
assert hard_majority.soft_majority(numpy.array(
41+
[11.0, 10.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0])) == 6.0
42+
assert hard_majority.soft_majority(numpy.array(
43+
[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0])) == 6.0
44+
45+
46+
def test_hard_majority():
47+
assert hard_majority.hard_majority(numpy.array([True])) == True
48+
assert hard_majority.hard_majority(numpy.array([False])) == False
49+
assert hard_majority.hard_majority(numpy.array([True, False])) == False
50+
assert hard_majority.hard_majority(
51+
numpy.array([False, True, False])) == False
52+
assert hard_majority.hard_majority(
53+
numpy.array([True, False, True, False])) == False
54+
assert hard_majority.hard_majority(numpy.array(
55+
[False, True, False, True, False])) == False
56+
assert hard_majority.hard_majority(numpy.array(
57+
[True, True, True, False, True, False])) == True
58+
assert hard_majority.hard_majority(numpy.array(
59+
[True, False, False, True, True, True, False])) == True
60+
assert hard_majority.hard_majority(numpy.array(
61+
[False, True, False, True, False, True, False, True])) == False
62+
assert hard_majority.hard_majority(numpy.array(
63+
[True, True, True, True, True, False, True, True, True])) == True
64+
assert hard_majority.hard_majority(numpy.array(
65+
[True, False, False, False, False, False, True, True, True, True])) == False
66+
67+
68+
def test_soft_and_hard_majority_equivalence():
69+
soft_maj = jax.jit(hard_majority.soft_majority)
70+
hard_maj = jax.jit(hard_majority.hard_majority)
71+
for i in range(1, 100):
72+
input = numpy.random.rand(i)
73+
soft_output = soft_maj(input)
74+
hard_output = hard_maj(harden.harden(input))
75+
assert harden.harden(soft_output) == hard_output
76+
77+
78+
def test_soft_majority_layer():
79+
assert numpy.all(hard_majority.soft_majority_layer(
80+
numpy.array([[2.0, 1.0], [1.0, 2.0]])) == numpy.array([1.0, 1.0]))
81+
assert numpy.all(hard_majority.soft_majority_layer(numpy.array(
82+
[[1.0, 2.0, 3.0], [3.0, 2.0, 1.0]])) == numpy.array([2.0, 2.0]))
83+
assert numpy.all(hard_majority.soft_majority_layer(numpy.array(
84+
[[1.0, 2.0, 3.0, 4.0], [4.0, 3.0, 2.0, 1.0]])) == numpy.array([2.0, 2.0]))
85+
assert numpy.all(hard_majority.soft_majority_layer(numpy.array(
86+
[[1.0, 2.0, 3.0, 4.0, 5.0], [5.0, 4.0, 3.0, 2.0, 1.0]])) == numpy.array([3.0, 3.0]))
87+
assert numpy.all(hard_majority.soft_majority_layer(numpy.array(
88+
[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [6.0, 5.0, 4.0, 3.0, 2.0, 1.0]])) == numpy.array([3.0, 3.0]))
89+
assert numpy.all(hard_majority.soft_majority_layer(numpy.array(
90+
[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], [7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0]])) == numpy.array([4.0, 4.0]))
91+
assert numpy.all(hard_majority.soft_majority_layer(numpy.array(
92+
[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], [8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0]])) == numpy.array([4.0, 4.0]))
93+
assert numpy.all(hard_majority.soft_majority_layer(numpy.array(
94+
[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0], [9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0]])) == numpy.array([5.0, 5.0]))
95+
assert numpy.all(hard_majority.soft_majority_layer(numpy.array(
96+
[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], [10.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0]])) == numpy.array([5.0, 5.0]))
97+
assert numpy.all(hard_majority.soft_majority_layer(numpy.array(
98+
[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0], [11.0, 10.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0]])) == numpy.array([6.0, 6.0]))
99+
assert numpy.all(hard_majority.soft_majority_layer(numpy.array(
100+
[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0], [12.0, 11.0, 10.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0]])) == numpy.array([6.0, 6.0]))
101+
102+
103+
def test_hard_majority_layer():
104+
assert numpy.all(hard_majority.hard_majority_layer(numpy.array(
105+
[[True, False], [False, True]])) == numpy.array([False, False]))
106+
assert numpy.all(hard_majority.hard_majority_layer(numpy.array(
107+
[[True, False, True], [True, False, True]])) == numpy.array([True, True]))
108+
assert numpy.all(hard_majority.hard_majority_layer(numpy.array(
109+
[[True, False, True, False], [False, True, False, True]])) == numpy.array([False, False]))
110+
assert numpy.all(hard_majority.hard_majority_layer(numpy.array(
111+
[[True, False, True, False, True], [True, False, True, False, True]])) == numpy.array([True, True]))
112+
assert numpy.all(hard_majority.hard_majority_layer(numpy.array([[True, False, True, False, True, False], [
113+
False, True, False, True, False, True]])) == numpy.array([False, False]))
114+
assert numpy.all(hard_majority.hard_majority_layer(numpy.array([[True, False, True, False, True, False, True], [
115+
True, False, True, False, True, False, False]])) == numpy.array([True, False]))
116+
117+
assert numpy.all(hard_majority.hard_majority_layer(numpy.array([[True, False], [
118+
False, True], [False, True]])) == numpy.array([False, False, False]))
119+
assert numpy.all(hard_majority.hard_majority_layer(numpy.array([[True, False, True], [
120+
True, False, True], [True, False, True]])) == numpy.array([True, True, True]))
121+
assert numpy.all(hard_majority.hard_majority_layer(numpy.array([[True, False, True, False], [
122+
False, True, False, True], [False, True, False, True]])) == numpy.array([False, False, False]))
123+
assert numpy.all(hard_majority.hard_majority_layer(numpy.array([[True, False, True, False, True], [
124+
True, False, True, False, True], [True, False, True, False, True]])) == numpy.array([True, True, True]))
125+
assert numpy.all(hard_majority.hard_majority_layer(numpy.array([[True, False, True, False, True, False], [
126+
False, True, False, True, False, True], [False, True, False, True, False, True]])) == numpy.array([False, False, False]))

0 commit comments

Comments
 (0)