# -*- coding: utf-8 -*-
"""
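Build the Bouncing example Bayes net (a time-delayed net), expand it into a
time series, enter findings and print the resulting beliefs, then save the
unexpanded net to a file in the 'Data Files' directory.

Requires Netica version 6.05 or later.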

@author: Sophie
"""

import sys, os
sys.path.append(os.path.abspath('..'))
import neticapy as netica

# Start Netica and print the message the environment reports on startup.
env = netica.Environ()
startup_mesg = env.mesg
print(startup_mesg)

# The time-delayed "Bouncing" net that the rest of the script builds.
net = netica.Net("Bouncing")

# --- Nodes ------------------------------------------------------------------
# "x" (position) and "v" (velocity) are created with 0 as the second argument;
# their discretization comes from the `intervals` assigned further down.
# "dt" is a CONSTANT node; its value is the time step used in the position
# equation and also names the delay on every link.
position = net.new_node("x", 0)
velocity = net.new_node("v", 0)
dt = net.new_node("dt", 0, kind="CONSTANT")

for node, title in ((position, "Position"), (velocity, "Velocity")):
    node.title = title

# Diagram coordinates of each node.
for node, coords in (
    (position, (130, 100)),
    (velocity, (130, 300)),
    (dt, (0, 50)),
):
    node.set_vis_position(*coords)

# --- Discretization ---------------------------------------------------------
# The position boundaries repeat 0 and 10 (the walls). Presumably this gives
# each wall its own zero-width state; confirm against Netica's interval rules.
position.intervals = [0, 0, 2, 4, 6, 8, 10, 10]
# Velocity boundaries -10, -8, ..., 10.
velocity.intervals = range(-10, 12, 2)

# --- Links, input names and delays ------------------------------------------
# Each row is (child, parents in link order, input names in the same order).
# The link order fixes the input indices used by set_input_name and
# set_input_delay, so it must stay as written.
wiring = (
    (position, (velocity, position), ("vd", "xd")),
    (velocity, (position, velocity), ("xd", "vd")),
)

for child, parents, _names in wiring:
    for parent in parents:
        child.add_link_from(parent)

# Time step.
dt.enter_value(0.25)

for child, _parents, names in wiring:
    for index, name in enumerate(names):
        child.set_input_name(index, name)

# Every input is delayed by the value of node "dt". "xd"/"vd" therefore
# presumably denote the previous time step's position/velocity.
# For Netica-Py versions 1.00 and earlier the call took an extra argument:
#     node.set_input_delay(index, 0, "dt")
for child, parents, _names in wiring:
    for index, _parent in enumerate(parents):
        child.set_input_delay(index, "dt")

# --- Equations ----------------------------------------------------------------
# Position advances by dt * velocity and is clipped to the walls at 0 and 10.
position.equation = "x (xd, vd) = clip (0, 10, xd + dt * vd)"
# Velocity reverses sign when the ball sits at a wall and is moving into it.
velocity.equation = "v (xd, vd) = (xd == 0  &&  vd < 0  ||  xd == 10  &&  vd > 0) ? -vd : vd"

# Build each node's table from its equation. The arguments presumably mirror
# the C API EquationToTable_bn (num_samples, samp_unc, add_exist); confirm.
for node in (position, velocity):
    node.equation_to_table(10000, False, False)

# Unroll the time-delayed net into an explicit time series, then compile
# the expanded copy so its beliefs can be queried.
expnet = net.expand_time_series(1, 2)
expnet.compile_net()

# Look up the unrolled nodes first (same order as before), then enter the
# two position findings.
findings = [
    (expnet.get_node_named(name), value)
    for name, value in (("x9", 1), ("x12", 3.7))
]
initial_velocity = expnet.get_node_named('v8')
for node, value in findings:
    node.enter_value(value)

print("Initial velocity intervals:", initial_velocity.intervals)
print("Probabilities of initial velocity:", initial_velocity.get_beliefs())

# Save the unexpanded (time-delayed) net, not the expanded copy.
# The path is kept in a single constant so the confirmation message always
# names the file that was actually written. It is relative to the current
# working directory, so the folder is created if missing; otherwise the write
# would fail.
BOUNCING_DNE_PATH = "Data Files/Bouncing.dne"
os.makedirs(os.path.dirname(BOUNCING_DNE_PATH), exist_ok=True)
net.write(netica.Stream(BOUNCING_DNE_PATH))

print("Built the 'Bouncing' Bayes net and saved it as file '" + BOUNCING_DNE_PATH + "'")

# Shut Netica down and print its closing message. The first element of the
# returned pair (presumably a status code; confirm) is not used.
_status, closing_mesg = env.close_netica()
print("\n" + closing_mesg)