DyNet XOR demo [Python version]
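This demo trains a small multilayer perceptron on an XOR-like function: the inputs come from {-1, 1}^2 and the target is +1 when the two inputs agree and -1 when they differ. The network built in calc_function below is

    h = tanh(W_xh * x + b_h)
    y = W_hy * h + b_y

with a 20-unit hidden layer, trained with plain SGD on the squared error between y and the target.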

import dynet as dy
import random

# Parameters of the model and training
HIDDEN_SIZE = 20
NUM_EPOCHS = 20

# Define the model and SGD optimizer
model = dy.Model()
W_xh_p = model.add_parameters((HIDDEN_SIZE, 2))
b_h_p = model.add_parameters(HIDDEN_SIZE)
W_hy_p = model.add_parameters((1, HIDDEN_SIZE))
b_y_p = model.add_parameters(1)
trainer = dy.SimpleSGDTrainer(model)

# Define the training data, consisting of (x,y) tuples
data = [([1,1],1), ([-1,1],-1), ([1,-1],-1), ([-1,-1],1)]

# Define the function we would like to calculate
def calc_function(x):
    dy.renew_cg()
    W_xh = dy.parameter(W_xh_p)
    b_h = dy.parameter(b_h_p)
    W_hy = dy.parameter(W_hy_p)
    b_y = dy.parameter(b_y_p)
    x_val = dy.inputVector(x)
    h_val = dy.tanh(W_xh * x_val + b_h)
    y_val = W_hy * h_val + b_y
    return y_val

# Perform training
for epoch in range(NUM_EPOCHS):
    epoch_loss = 0
    random.shuffle(data)
    for x, ystar in data:
        y = calc_function(x)
        loss = dy.squared_distance(y, dy.scalarInput(ystar))
        epoch_loss += loss.value()
        loss.backward()
        trainer.update()
    print("Epoch %d: loss=%f" % (epoch, epoch_loss))

# Print results of prediction
for x, ystar in data:
    y = calc_function(x)
    print("%r -> %f" % (x, y.value()))
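Both the DyNet parameter initialization and random.shuffle draw on random seeds, so the per-epoch losses below will differ from run to run; DyNet prints the seed it picked when it starts up. A minimal sketch of pinning both sources of randomness, assuming the installed DyNet still reads its options (including --dynet-seed) from the command line at import time:

# Hypothetical invocation: python xor.py --dynet-seed 1174664263
import random
import dynet as dy   # DyNet picks up --dynet-seed from sys.argv when it is imported

random.seed(42)      # fixes the shuffling order used in the training loop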

Output:

[dynet] random seed: 1174664263
[dynet] allocating memory: 512MB
[dynet] memory allocation done.
Epoch 0: loss=12.391680
Epoch 1: loss=8.196088
Epoch 2: loss=8.103037
Epoch 3: loss=8.636450
Epoch 4: loss=7.573008
Epoch 5: loss=4.910318
Epoch 6: loss=3.079966
Epoch 7: loss=1.328273
Epoch 8: loss=1.171368
Epoch 9: loss=0.515850
Epoch 10: loss=1.885216
Epoch 11: loss=0.568994
Epoch 12: loss=0.278629
Epoch 13: loss=0.025215
Epoch 14: loss=0.018466
Epoch 15: loss=0.055305
Epoch 16: loss=0.014131
Epoch 17: loss=0.010476
Epoch 18: loss=0.003893
Epoch 19: loss=0.003332
[1, 1] -> 1.049703
[-1, 1] -> -0.996379
[1, -1] -> -0.974599
[-1, -1] -> 0.995763
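The listing above uses the older DyNet Python API, which is what produced the output shown. In later 2.x releases (an assumption about the installed version), dy.Model was renamed dy.ParameterCollection and Parameter objects can be used directly inside expressions, so the explicit dy.parameter() calls in calc_function are no longer needed. A rough sketch of the same network under that newer API:

import dynet as dy

HIDDEN_SIZE = 20

# ParameterCollection is the newer name for dy.Model
model = dy.ParameterCollection()
W_xh = model.add_parameters((HIDDEN_SIZE, 2))
b_h = model.add_parameters(HIDDEN_SIZE)
W_hy = model.add_parameters((1, HIDDEN_SIZE))
b_y = model.add_parameters(1)
trainer = dy.SimpleSGDTrainer(model)

def calc_function(x):
    dy.renew_cg()
    x_val = dy.inputVector(x)
    # Parameter objects can be used as expressions directly, without dy.parameter()
    h_val = dy.tanh(W_xh * x_val + b_h)
    return W_hy * h_val + b_y

The training and prediction loops remain exactly as in the original listing.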