# -*- coding: utf-8 -*-
"""
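Load the learned Situation Assessment Bayes net from the 'Data Files'
directory, test its predictions of the Ave_of_Approach node against a
case file, and print the confusion matrix, error rate and log loss.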

For netica version 6.05 or later

@author: Sophie
"""

import sys, os
sys.path.append(os.path.abspath('..'))
import neticapy as netica

env = netica.Environ()
net = netica.Net(None, "Data Files/Learned_1000_Sit_Assess.neta", "NO_VISUAL_INFO")

# All nodes in the network.
Nodes = net.nodes

# Node list passed as the second argument of netica.Tester below; it is
# left empty here.
# NOTE(review): in the Netica C API (NewNetTester_bn) this argument is the
# list of "unobserved" nodes whose case-file values are hidden during the
# test, not a list of finding nodes. Confirm whether the target node should
# be in it (e.g. finding_nodes.add(testerTargetNode, None) after the lookup
# below); as written, the target's own case value may be entered as a
# finding. TODO confirm against the Netica documentation.
finding_nodes = netica.NodeList(0, net)

# Nodes whose predictions the tester scores: only the target node.
test_node = netica.NodeList(0, net)

testerTargetNode = net.get_node_named("Ave_of_Approach")
test_node.add(testerTargetNode, None)

# IMPORTANT: retract any findings already in the net, otherwise they will
# become part of every test.
net.retract_findings()
net.compile_net()

# NOTE(review): the meaning of the final -1 argument is not visible here;
# confirm it against the netica.Tester documentation.
tester = netica.Tester(test_node, finding_nodes, -1)

# Load the test cases from the case file into a case set.
casefile = netica.Stream("Data Files/Situation Assessment Cases 10.cas")
caseset = netica.Caseset("TestCases")
caseset.add_cases_from_file(casefile, 1)  # presumably a case degree/weight of 1 -- confirm

# Test the net against the cases; the statistics are reported further down.
tester.test_with_cases(caseset)

def print_confusion_matrix(tester, node, *, min_width=15):
    """Print the confusion matrix that `tester` collected for `node`.

    The first table line lists the state names of `node` (one column per
    predicted state) followed by an "Actual" heading.  Each following line
    holds the counts for one actual state, with that state's name at the
    end of the row.

    Parameters
    ----------
    tester : netica.Tester
        A tester that has already been run (e.g. via ``test_with_cases``).
    node : netica.Node
        The node whose confusion matrix is printed (presumably one of the
        tester's test nodes).
    min_width : int, optional
        Minimum column width (default 15).  Columns are widened
        automatically so long state names or counts never run together.
    """
    states = range(node.num_states)
    names = [node.get_state_name(i) for i in states]
    # rows[a][p]: count for actual state a (row) and predicted state p
    # (column), following the call get_confusion_matrix(node, p, a).
    rows = [[repr(tester.get_confusion_matrix(node, p, a)) for p in states]
            for a in states]

    # str.ljust() never truncates, so a cell as wide as the column would
    # touch its neighbour; keep at least one space of separation.
    cells = names + ["Actual"] + [cell for row in rows for cell in row]
    width = max([min_width] + [len(cell) + 1 for cell in cells])

    print("\nConfusion matrix for {}:\n".format(node.name))
    print("".join(name.ljust(width) for name in names) + "Actual".ljust(width))
    for name, row in zip(names, rows):
        print("".join(cell.ljust(width) for cell in row) + name.ljust(width))


# Report the results of the case-file test for the target node.
target = testerTargetNode
print_confusion_matrix(tester, target)

target_name = target.name
error_pct = tester.error_rate(target) * 100.0
print("\nError rate for {} = {}%\n".format(target_name, error_pct))

loss = tester.log_loss(target)
print("Logarithmic loss for {} = {}\n".format(target_name, loss))