
import sys, os
sys.path.append(os.path.abspath('..'))
import neticapy as netica

import random

show = True  # Module-wide switch: when True, the functions below print progress output.

def next_states(states, nodes):
    """Step STATES to the next configuration of NODES' states, odometer style.

    STATES holds one state index per node of NODES, and is updated in place.
    NODES must provide ``length`` and ``get_nth_node(i).num_states``.
    The configurations are the elements of the cartesian product of the
    nodes' states, with the last entry changing fastest.

    Returns False after an ordinary step, and True once every configuration
    has been visited (STATES has then rolled over to all zeros).
    Initialize STATES (usually to all zeros) before the first call.
    Python counterpart of Netica-Ex NextStates.
    """
    position = nodes.length
    while position > 0:
        position -= 1
        states[position] += 1
        if states[position] < nodes.get_nth_node(position).num_states:
            return False
        # This "digit" overflowed: reset it to 0 and carry into the one on its left.
        states[position] = 0
    return True

"""        
STATES is a list of node states, one for each node of NODES.
This cycles through all possible configurations (i.e. elements of the cartesian
product) of STATES, odometer style, with the last state changing fastest.
It returns TRUE when all the configurations have been examined (i.e., when it
"rolls over" to all zeros again).
Don't forget to initialize STATES before calling it the first time (usually 
to all zeros).


                bool_ns NextStates (state_bn* states, const nodelist_bn* nodes){
    int nn;
    for (nn = LengthNodeList_bn (nodes) - 1;  nn >= 0;  --nn){      # for each parent
        if (++states[nn] < GetNodeNumberStates_bn (NthNode_bn (nodes, nn)))  # 
            return FALSE;
        states[nn] = 0;
        }
    return TRUE;
    }

"""

def size_cartesian_product(nodes):
    """Return the number of configurations in the cartesian product of NODES' states.

    This is the product of ``num_states`` over every node in NODES, i.e. how
    many configurations ``next_states`` steps through. NODES must provide
    ``length`` and ``get_nth_node(i).num_states``. An empty NODES gives 1
    (the single empty configuration).

    Returns 0 as soon as a node reports 0 states (e.g. a continuous node that
    has not been discretized).

    Unlike the C version, there is no DBL_MAX cap. Python integers do not
    overflow, so the exact (possibly very large) product is returned.
    """
    size = 1
    for nn in range(nodes.length):
        num_states = nodes.get_nth_node(nn).num_states
        if num_states == 0:
            return 0
        size *= num_states
    return size

def test_password(password):
    """Return PASSWORD unchanged.

    Python stand-in for the C test code's TEST_PASSWORD(...) wrapper around the
    Netica license string. It currently performs no transformation.
    """
    return password


# The name starts with "test_", so pytest would otherwise collect this helper as a
# test (also when it is imported via "from ... import *") and fail looking for a
# "password" fixture.
test_password.__test__ = False

def random_int(max_int):
    """Return a uniformly chosen integer in the half-open range [0, MAX_INT).

    Draws from Python's global ``random`` module, not from a Netica random
    generator. MAX_INT must be at least 1; otherwise ``ValueError`` is raised.
    NOTE(review): originally marked "placeholder?" -- possibly meant to be
    replaced by a Netica RandomGenerator draw; confirm.
    """
    return random.randrange(max_int)

def random_dbl(num):
    """Return a uniform random float in the half-open interval [0.0, 1.0).

    Draws from Python's global ``random`` module.
    NOTE(review): NUM is accepted but ignored. The only caller passes 1.0,
    which suggests it was meant as an upper bound; confirm against the C RandomDbl.
    """
    sample = random.random()
    return sample

# make random net


def _random_num_states(rand, min_num_states, max_num_states, infer_test):
    """Pick the number of states for a new node of a random test net.

    Single-state nodes are made only rarely (about 1 draw in 40, when allowed),
    since they are wanted just for test purposes. Otherwise the count is >= 2.
    """
    if max_num_states == 1:
        return 1
    if infer_test:
        # Inference tests ignore MIN_NUM_STATES: rarely 1 state, else 2..max inclusive.
        if rand.generate_random_numbers(1)[0] < 1 / 40:
            return 1
        return 2 + random_int(max_num_states - 1)
    if min_num_states == 1 and rand.generate_random_numbers(1)[0] < 1 / 40:
        return 1
    # Uniform over low..max inclusive (also valid when min == max).
    low = max(min_num_states, 2)
    return low + random_int(max_num_states - low + 1)


def _random_distribution(rand, numstates):
    """Return a random probability vector over NUMSTATES states that sums to 1.

    About 1 entry in 10 is forced to 0 so the tables contain some zeros. The
    rest are a random draw in [0, 1) cubed, which skews values toward 0, and
    are then normalized. If every entry came out 0, all the mass goes on state 0.
    """
    probs = [0] * numstates
    total = 0
    for sn in range(numstates):
        if rand.generate_random_numbers(1)[0] < 1 / 10:   # makes sure there are some 0s
            continue
        probs[sn] = pow(random_dbl(1.0), 3)
        total += probs[sn]
    if total == 0:
        probs[0] = 1
        return probs
    return [p / total for p in probs]


def _fill_random_cpt(rand, node, numstates):
    """Give every row of NODE's CPT (one per parent configuration) a random distribution."""
    parents = node.parents
    parent_states = [0] * parents.length
    finished = False
    # do-while: a node with no parents still gets its single (unconditional) row.
    while not finished:
        node.set_table("cpt", parent_states, _random_distribution(rand, numstates))
        finished = next_states(parent_states, parents)


def make_random_net(rand, min_num_nodes, max_num_nodes, min_num_states, max_num_states, min_num_parents, max_num_parents, infer_test):
    """Build a Netica net with a random structure and random CPTs, for testing.

    Nodes are named ``node_0``, ``node_1``, ... . Each node gets a random number
    of states, is linked to some randomly chosen parents, and has every row of
    its CPT filled with a random distribution.

    Args:
        rand: Netica random generator (e.g. ``netica.RandomGenerator("17")``).
            It only supplies the 1-in-40 "single state" draws and the 1-in-10
            "zero probability" draws. Node, state and parent counts, parent
            choices and probability magnitudes come from Python's global
            ``random`` module (via random_int / random_dbl), so seeding RAND
            alone does not make the net reproducible.
        min_num_nodes, max_num_nodes: the node count is uniform in this
            inclusive range.
        min_num_states, max_num_states: state-count range for each node
            (1 state only rarely; see _random_num_states).
        min_num_parents, max_num_parents: range for the number of parent links
            tried per node. It is capped at nn for node nn, and skipped
            duplicate/self links make the actual count lower.
        infer_test: if true, state counts ignore MIN_NUM_STATES and are 2..max,
            with an occasional single-state node.

    Returns:
        The new ``netica.Net``.

    Raises:
        AssertionError: if any min/max range is invalid.
    """
    assert 0 <= min_num_nodes and min_num_nodes <= max_num_nodes, "0 <= min nodes <= max nodes"
    assert 1 <= min_num_states and min_num_states <= max_num_states, "1 <= min states <= max states"
    assert 0 <= min_num_parents and min_num_parents <= max_num_parents, "0 <= min parents <= max parents"

    net = netica.Net("rand")
    net.set_num_undos_kept(0, 0)  # this makes it faster, but comment it out occasionally to test Undo system

    numnodes = min_num_nodes + random_int(max_num_nodes - min_num_nodes + 1)
    allow_multiple_links = False  # set True to allow several links between the same pair of nodes

    for nn in range(numnodes):
        numstates = _random_num_states(rand, min_num_states, max_num_states, infer_test)
        node = net.new_node(f"node_{nn}", numstates)

        # The first node must have 0 parents, the second a max of 1, etc.
        numparents_try = min(min_num_parents + random_int(max_num_parents - min_num_parents + 1), nn)
        numparents = 0
        for _ in range(numparents_try):
            parent = net.nodes.get_nth_node(random_int(nn + 1))
            # NOTE(review): `is` relies on neticapy returning the same Python object
            # for the same node (the C code used node user data for this check);
            # confirm, otherwise self-links may be attempted.
            if not allow_multiple_links and (parent is node or parent in node.parents.nodes):
                continue
            node.add_link_from(parent)
            numparents += 1

        if show:
            print(f"   p{numparents} s{numstates}")

        # Once per node, after all its links exist (the table's shape depends on them).
        _fill_random_cpt(rand, node, numstates)

    if show:
        print("\n\n")
    return net

# Build one random test net and save it to a .dne file.
env = netica.Environ()

random_net = make_random_net(
    rand=netica.RandomGenerator("17"),
    min_num_nodes=5,
    max_num_nodes=15,
    min_num_states=1,
    max_num_states=4,
    min_num_parents=0,
    max_num_parents=4,
    infer_test=True,
)  # 20 -> free version  (NOTE(review): presumably about the free Netica license's node limit; confirm)
random_net.write(netica.Stream("Data Files/RandomNet.dne"))


# simulate cases

# learn tan net


#____________________________________________________________________ LearnTanRandomCases_TestAPI

def learn_tan_random_cases_test_API(learn_tan_random_cases=None, dirname="WDIR", num_tests=10, first_file_num=1):
    """Run NUM_TESTS rounds of the random TAN-learning test.

    Work-in-progress Python port of the C function LearnTanRandomCases_TestAPI.
    For each net number, it builds a case-file name and a random generator
    seeded from that number. It then calls LEARN_TAN_RANDOM_CASES, which should
    make a random net, simulate a case set into the file, learn a TAN net from
    it and test that net.

    Args:
        learn_tan_random_cases: callable ``(fname, rand) -> int`` that runs one
            test and returns a negative value on failure (having reported it).
            The C LearnTanRandomCases has not been ported yet, so this must be
            supplied.
        dirname: prefix for the case-file names. The C code uses its WDIR
            working-directory macro; "WDIR" is only a placeholder.
        num_tests: how many nets to test.
        first_file_num: number of the first net. Seeds are derived from the net
            number, so the sequence can be started at any net.

    Returns:
        The result of the last call to LEARN_TAN_RANDOM_CASES (negative if it
        failed, which also stops the run), or 0 if no tests were run.

    Raises:
        NotImplementedError: if LEARN_TAN_RANDOM_CASES is not supplied.
    """
    if learn_tan_random_cases is None:
        raise NotImplementedError(
            "LearnTanRandomCases has not been ported to Python yet; "
            "pass a learn_tan_random_cases(fname, rand) callable")

    if show:
        print("\nLEARN RANDOM TAN TEST:\n---------------------\n")

    # C: NewNeticaEnviron_ns + InitNetica2_bn.
    # NOTE(review): the module already creates an Environ at import time; confirm
    # Netica allows a second one here.
    env = netica.Environ(test_password("+Test1/Norsys/all/39512"))
    # TODO: port InitStats() / CheckStats() and closing ENV (CloseNetica_bn).

    res = 0
    for file_num in range(first_file_num, first_file_num + num_tests):
        if show:
            print(f"\n-------- Testing Net {file_num} --------\n\n")
        fname = f"{dirname}RandomTan{file_num}.cas"
        # Seeded from the net number, so a run can start at any net in the sequence.
        rand = netica.RandomGenerator(str(file_num + 2000))
        res = learn_tan_random_cases(fname, rand)
        if res < 0:
            return res  # the callee has already reported the error

    if show:
        print(f"\nRan {num_tests} tests of making random net, simulating case-set, "
              "learning TAN net and testing it.\n\n")
    return res

# TODO: put a version of this in NeticaEx.

"""
#____________________________________________________________________ LearnTanRandomCases_TestAPI

int LearnTanRandomCases_TestAPI (void){
	char fname[300];
	int itest, res;
	char mesg[MESG_LEN_ns], seed[19+1];
	const char* dirname = WDIR;					// was  WDIR"TAN Learning Test/"  but on Marc's Linux fopen couldnt create a directory, and I didnt want to pre-put in the directory since the WDIR (ie, 'work') directory gets periodically emptied
	const int num_tests = 10;
	
	if (show)  printf ("\nLEARN RANDOM TAN TEST:\n---------------------\n");
	env = NewNeticaEnviron_ns (TEST_PASSWORD("+Test1/Norsys/all/39512"), 0, 0);
	res = InitNetica2_bn (env, mesg);
	if (res < 0)  FAIL ((23887, "Init Netica failure: %s", mesg))
	
	InitStats();
	File_Num = 1;													// this can be set to any value to study specific nets /@6
	for (itest = 0;  itest < num_tests;  ++itest){
		if (doprint)  printf ("\n-------- Testing Net %d --------\n\n", (int) File_Num);
		sprintf (fname, "%s%s%d.cas", dirname, "RandomTan", File_Num);
		sprintf (seed, "%d", File_Num + 2000);						// so we can set the initial value of 'File_Num' to other values to start at any net in the grand sequence /@6
		randgen_ns* old_rg1 = rg1;
		rg1 = NewRandomGenerator_ns (seed, NULL, NULL);
		res= LearnTanRandomCases (fname);
		DeleteRandomGen_ns (rg1);
		rg1 = old_rg1;
		if (res < 0)  return res;									// error already displayed
		File_Num ++;
		CHK(23888)													// just for safety
		}
	if (doprint)  printf ("\nRan %d tests of making random net, simulating case-set, learning TAN net and testing it.\n\n", (int) num_tests);
	CheckStats();

	CHK(23889)
	res= CloseNetica_bn (env, mesg);
	if (res < 0)  FAIL ((23890, "Close Netica failure: %s", mesg))
	return res;
	}

//= Put version in NeticaEx
"""

# print confusion matrix?