
Commit a0f7af7

rename and refactor the learnable masks
1 parent 7a1910b · commit a0f7af7

9 files changed: +325 -274 lines changed


neurallogic/hard_and.py

Lines changed: 24 additions & 30 deletions
@@ -1,38 +1,19 @@
-from typing import Any
+from typing import Callable

 import jax
 from flax import linen as nn
-from typing import Callable
-
-
-from neurallogic import neural_logic_net, symbolic_generation
-
-
-def soft_and_include(w: float, x: float) -> float:
-    """
-    w > 0.5 implies the and operation is active, else inactive
-
-    Assumes x is in [0, 1]
-
-    Corresponding hard logic: x OR ! w
-    """
-    w = jax.numpy.clip(w, 0.0, 1.0)
-    return jax.numpy.maximum(x, 1.0 - w)
-
-
-
-def hard_and_include(w, x):
-    return jax.numpy.logical_or(x, jax.numpy.logical_not(w))

+from neurallogic import hard_masks, neural_logic_net, symbolic_generation


+# TODO: seperate and operation from mask operation
 def soft_and_neuron(w, x):
-    x = jax.vmap(soft_and_include, 0, 0)(w, x)
+    x = jax.vmap(hard_masks.soft_mask_to_true, 0, 0)(w, x)
     return jax.numpy.min(x)


 def hard_and_neuron(w, x):
-    x = jax.vmap(hard_and_include, 0, 0)(w, x)
+    x = jax.vmap(hard_masks.hard_mask_to_true, 0, 0)(w, x)
     return jax.lax.reduce(x, True, jax.lax.bitwise_and, [0])


@@ -41,6 +22,7 @@ def hard_and_neuron(w, x):
 hard_and_layer = jax.vmap(hard_and_neuron, (0, None), 0)


+# TODO: move initialization to separate file
 def initialize_near_to_zero():
     # TODO: investigate better initialization
     def init(key, shape, dtype):
@@ -51,6 +33,7 @@ def init(key, shape, dtype):
         x = 0.5 * x - 1
         x = jax.numpy.clip(x, 0.001, 0.999)
         return x
+
     return init


@@ -62,15 +45,17 @@ class SoftAndLayer(nn.Module):
         layer_size: The number of neurons in the layer.
         weights_init: The initializer function for the weight matrix.
     """
+
     layer_size: int
     weights_init: Callable = initialize_near_to_zero()
     dtype: jax.numpy.dtype = jax.numpy.float32

     @nn.compact
     def __call__(self, x):
         weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
-        weights = self.param('bit_weights', self.weights_init,
-                             weights_shape, self.dtype)
+        weights = self.param(
+            "bit_weights", self.weights_init, weights_shape, self.dtype
+        )
         x = jax.numpy.asarray(x, self.dtype)
         return soft_and_layer(weights, x)

@@ -83,13 +68,15 @@ class HardAndLayer(nn.Module):
     Attributes:
         layer_size: The number of neurons in the layer.
     """
+
     layer_size: int

     @nn.compact
     def __call__(self, x):
         weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
         weights = self.param(
-            'bit_weights', nn.initializers.constant(True), weights_shape)
+            "bit_weights", nn.initializers.constant(True), weights_shape
+        )
         return hard_and_layer(weights, x)


@@ -104,6 +91,13 @@ def __call__(self, x):


 and_layer = neural_logic_net.select(
-    lambda layer_size, weights_init=initialize_near_to_zero(), dtype=jax.numpy.float32: SoftAndLayer(layer_size, weights_init, dtype),
-    lambda layer_size, weights_init=nn.initializers.constant(True), dtype=jax.numpy.float32: HardAndLayer(layer_size),
-    lambda layer_size, weights_init=nn.initializers.constant(True), dtype=jax.numpy.float32: SymbolicAndLayer(layer_size))
+    lambda layer_size, weights_init=initialize_near_to_zero(), dtype=jax.numpy.float32: SoftAndLayer(
+        layer_size, weights_init, dtype
+    ),
+    lambda layer_size, weights_init=nn.initializers.constant(
+        True
+    ), dtype=jax.numpy.float32: HardAndLayer(layer_size),
+    lambda layer_size, weights_init=nn.initializers.constant(
+        True
+    ), dtype=jax.numpy.float32: SymbolicAndLayer(layer_size),
+)
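
For orientation, a minimal sketch of how the refactored AND neuron behaves (assumed usage with made-up weights and inputs, not part of the commit): an input whose mask weight is below 0.5 is pushed towards 1.0, the identity of AND, so only the included inputs determine the result.

import jax.numpy as jnp

from neurallogic import hard_and

w = jnp.array([0.9, 0.1, 0.8])  # mask weights: include, exclude, include
x = jnp.array([1.0, 0.0, 1.0])  # soft input bits

# The excluded 0.0 is masked towards 1.0, so it no longer drags the min down.
print(hard_and.soft_and_neuron(w, x))  # ~0.9, i.e. close to True

# The hard neuron agrees once weights and inputs are thresholded to Booleans.
print(hard_and.hard_and_neuron(w > 0.5, x > 0.5))  # True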

neurallogic/hard_masks.py

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+import jax
+
+
+def soft_mask_to_true(w: float, x: float) -> float:
+    """
+    w > 0.5 implies the and operation is active, else inactive
+
+    Assumes x is in [0, 1]
+
+    Corresponding hard logic: x OR ! w
+    """
+    w = jax.numpy.clip(w, 0.0, 1.0)
+    return jax.numpy.maximum(x, 1.0 - w)
+
+
+def hard_mask_to_true(w, x):
+    return jax.numpy.logical_or(x, jax.numpy.logical_not(w))
+
+
+def soft_mask_to_false(w: float, x: float) -> float:
+    """
+    w > 0.5 implies the and operation is active, else inactive
+
+    Assumes x is in [0, 1]
+
+    Corresponding hard logic: b AND w
+    """
+    w = jax.numpy.clip(w, 0.0, 1.0)
+    return 1.0 - jax.numpy.maximum(1.0 - x, 1.0 - w)
+
+
+def hard_mask_to_false(w, x):
+    return jax.numpy.logical_and(x, w)
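
As a quick illustration of the mask semantics documented above (a minimal sketch with made-up values): soft_mask_to_true sends an excluded input towards 1.0, the identity of AND, while soft_mask_to_false sends it towards 0.0, the identity of OR and XOR; at the Boolean corners the hard variants agree with the rounded soft ones.

import jax.numpy as jnp

from neurallogic import hard_masks

# w > 0.5: the input passes through; w < 0.5: the input is masked out.
print(hard_masks.soft_mask_to_true(0.9, 0.2))   # 0.2  (included)
print(hard_masks.soft_mask_to_true(0.1, 0.2))   # 0.9  (excluded, pushed towards 1.0)
print(hard_masks.soft_mask_to_false(0.9, 0.8))  # 0.8  (included)
print(hard_masks.soft_mask_to_false(0.1, 0.8))  # 0.1  (excluded, pushed towards 0.0)

# Agreement with the hard masks at the Boolean corners.
for w in (False, True):
    for x in (False, True):
        soft_t = hard_masks.soft_mask_to_true(float(w), float(x))
        soft_f = hard_masks.soft_mask_to_false(float(w), float(x))
        assert bool(jnp.round(soft_t)) == bool(hard_masks.hard_mask_to_true(w, x))
        assert bool(jnp.round(soft_f)) == bool(hard_masks.hard_mask_to_false(w, x))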

neurallogic/hard_or.py

Lines changed: 20 additions & 28 deletions
@@ -3,32 +3,17 @@
 import jax
 from flax import linen as nn

-from neurallogic import neural_logic_net, symbolic_generation
-
-
-def soft_or_include(w: float, x: float) -> float:
-    """
-    w > 0.5 implies the and operation is active, else inactive
-
-    Assumes x is in [0, 1]
-
-    Corresponding hard logic: b AND w
-    """
-    w = jax.numpy.clip(w, 0.0, 1.0)
-    return 1.0 - jax.numpy.maximum(1.0 - x, 1.0 - w)
-
-
-def hard_or_include(w, x):
-    return jax.numpy.logical_and(x, w)
+from neurallogic import neural_logic_net, symbolic_generation, hard_masks


+# TODO: seperate out the or operation from the mask operation
 def soft_or_neuron(w, x):
-    x = jax.vmap(soft_or_include, 0, 0)(w, x)
+    x = jax.vmap(hard_masks.soft_mask_to_false, 0, 0)(w, x)
     return jax.numpy.max(x)


 def hard_or_neuron(w, x):
-    x = jax.vmap(hard_or_include, 0, 0)(w, x)
+    x = jax.vmap(hard_masks.hard_mask_to_false, 0, 0)(w, x)
     return jax.lax.reduce(x, False, jax.lax.bitwise_or, [0])


@@ -48,6 +33,7 @@ def init(key, shape, dtype):
         x = 0.5 * x + 1
         x = jax.numpy.clip(x, 0.001, 0.999)
         return x
+
     return init


@@ -60,8 +46,9 @@ class SoftOrLayer(nn.Module):
     def __call__(self, x):
         weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
         weights = self.param(
-            'bit_weights', self.weights_init, weights_shape, self.dtype)
-        x = jax.numpy.asarray(x, self.dtype)
+            "bit_weights", self.weights_init, weights_shape, self.dtype
+        )
+        x = jax.numpy.asarray(x, self.dtype)
         return soft_or_layer(weights, x)


@@ -72,7 +59,8 @@ class HardOrLayer(nn.Module):
     def __call__(self, x):
         weights_shape = (self.layer_size, jax.numpy.shape(x)[-1])
         weights = self.param(
-            'bit_weights', nn.initializers.constant(True), weights_shape)
+            "bit_weights", nn.initializers.constant(True), weights_shape
+        )
         return hard_or_layer(weights, x)


@@ -82,14 +70,18 @@ def __init__(self, layer_size):
         self.hard_or_layer = HardOrLayer(self.layer_size)

     def __call__(self, x):
-        jaxpr = symbolic_generation.make_symbolic_flax_jaxpr(
-            self.hard_or_layer, x)
+        jaxpr = symbolic_generation.make_symbolic_flax_jaxpr(self.hard_or_layer, x)
         return symbolic_generation.symbolic_expression(jaxpr, x)


 or_layer = neural_logic_net.select(
-    lambda layer_size, weights_init=initialize_near_to_one(
-    ), dtype=jax.numpy.float32: SoftOrLayer(layer_size, weights_init, dtype),
+    lambda layer_size, weights_init=initialize_near_to_one(), dtype=jax.numpy.float32: SoftOrLayer(
+        layer_size, weights_init, dtype
+    ),
+    lambda layer_size, weights_init=nn.initializers.constant(
+        True
+    ), dtype=jax.numpy.float32: HardOrLayer(layer_size),
     lambda layer_size, weights_init=nn.initializers.constant(
-        True), dtype=jax.numpy.float32: HardOrLayer(layer_size),
-    lambda layer_size, weights_init=nn.initializers.constant(True), dtype=jax.numpy.float32: SymbolicOrLayer(layer_size))
+        True
+    ), dtype=jax.numpy.float32: SymbolicOrLayer(layer_size),
+)
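
The OR neuron is the dual case (again a minimal sketch with made-up values, not part of the commit): an input excluded by its mask weight is pushed towards 0.0 and therefore cannot switch the OR on.

import jax.numpy as jnp

from neurallogic import hard_or

w = jnp.array([0.9, 0.1, 0.9])  # mask weights: include, exclude, include
x = jnp.array([0.0, 1.0, 0.0])  # only the excluded input is high

# The excluded 1.0 is masked towards 0.0, so the max stays low.
print(hard_or.soft_or_neuron(w, x))  # ~0.1, i.e. close to False

print(hard_or.hard_or_neuron(w > 0.5, x > 0.5))  # False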

neurallogic/hard_xor.py

Lines changed: 6 additions & 21 deletions
@@ -3,28 +3,13 @@
 import jax
 from flax import linen as nn

-from neurallogic import neural_logic_net, symbolic_generation, hard_and
-
-
-def soft_xor_include(w: float, x: float) -> float:
-    """
-    w > 0.5 implies the and operation is active, else inactive
-
-    Assumes x is in [0, 1]
-
-    Corresponding hard logic: b AND w
-    """
-    w = jax.numpy.clip(w, 0.0, 1.0)
-    return 1.0 - jax.numpy.maximum(1.0 - x, 1.0 - w)
-
-
-def hard_xor_include(w, x):
-    return jax.numpy.logical_and(x, w)
+from neurallogic import neural_logic_net, symbolic_generation, hard_masks


+# TODO: seperate out the mask from the xor operation
 def soft_xor_neuron(w, x):
     # Conditionally include input bits, according to weights
-    x = jax.vmap(soft_xor_include, 0, 0)(w, x)
+    x = jax.vmap(hard_masks.soft_mask_to_false, 0, 0)(w, x)

     def xor(x, y):
         return jax.numpy.minimum(jax.numpy.maximum(x, y), 1.0 - jax.numpy.minimum(x, y))
@@ -34,7 +19,7 @@ def xor(x, y):


 def hard_xor_neuron(w, x):
-    x = jax.vmap(hard_xor_include, 0, 0)(w, x)
+    x = jax.vmap(hard_masks.hard_mask_to_false, 0, 0)(w, x)
     return jax.lax.reduce(x, False, jax.lax.bitwise_xor, [0])


@@ -48,8 +33,8 @@ class SoftXorLayer(nn.Module):
     layer_size: int
     weights_init: Callable = (
         nn.initializers.uniform(1.0)
-        #hard_and.initialize_near_to_zero()
-        )
+        # hard_and.initialize_near_to_zero()
+    )
     dtype: jax.numpy.dtype = jax.numpy.float32

     @nn.compact
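
The soft XOR used in soft_xor_neuron relaxes x XOR y to min(max(x, y), 1 - min(x, y)). A quick corner check, with the local xor helper from the diff restated as a standalone function for illustration (the name soft_xor is ours, not the commit's):

import jax.numpy as jnp

def soft_xor(x, y):
    # High exactly when one of x, y is high and the other is low.
    return jnp.minimum(jnp.maximum(x, y), 1.0 - jnp.minimum(x, y))

for a in (0.0, 1.0):
    for b in (0.0, 1.0):
        assert float(soft_xor(a, b)) == float(bool(a) ^ bool(b))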
