

import ctypes
import weakref

from neticapy import enums
from neticapy import environ as envrn
from neticapy import stream as strm
from neticapy import net as nt
from neticapy import node as nd
from neticapy import nodelist as ndlst
from neticapy import neticaerror as err
from neticapy.loaddll import Netica

"""------------------------------Caseset Class------------------------------"""

class Caseset:
    """A set of cases, wrapping the Netica-C object made by NewCaseset_cs.

    The Netica object is freed (DeleteCaseset_cs) when this Python object is
    garbage collected.
    """

    def __init__(self, name, environ=None):
        """Create an initially empty set of cases.

        'name' argument is not required (i.e. may be None).
        'environ' is the Environ to create it in; None means the default
        environment (envrn.env).
        Like Netica-C NewCaseset_cs.

        Raises TypeError if 'environ' is neither None nor an Environ.
        """
        # Set before anything can fail, so __del__ can tell that construction
        # never completed and there is nothing to free.
        self.cptr = None
        # Streams this caseset reads from.  The caseset "refers to" its files
        # (see __del__), so they are kept alive as long as it is.
        # NOTE(review): confirm against Netica docs that AddFileToCaseset_cs
        # does not copy the cases immediately.
        self._streams = []

        if name is not None:
            name = ctypes.c_char_p(name.encode())

        if environ is None:
            environ_cptr = envrn.env
        elif isinstance(environ, envrn.Environ):
            environ_cptr = environ.cptr
        else:
            raise TypeError('An Environ or None is required (got type {})'.format(type(environ).__name__))

        Netica.NewCaseset_cs.restype = ctypes.c_void_p
        cptr = Netica.NewCaseset_cs(name, ctypes.c_void_p(environ_cptr))
        err.checkerr()

        self.cptr = cptr

        if envrn.dict_initialization:
            envrn.cptr_dict[cptr] = weakref.ref(self)

    def __del__(self):
        """Remove this Caseset, freeing all resources it consumes, including memory.

        Will not delete files this Caseset refers to.
        Safe to call on an object whose construction failed, and safe to call
        more than once; only the first call with a live pointer frees it.
        Like Netica-C DeleteCaseset_cs.
        """
        cptr = getattr(self, 'cptr', None)
        if cptr is None:
            # Construction failed or already deleted: nothing to free and no
            # registry entry to remove.
            return
        try:
            if envrn.env is not None:
                Netica.DeleteCaseset_cs.restype = None
                Netica.DeleteCaseset_cs(ctypes.c_void_p(cptr))
                err.checkerr()
        finally:
            # Always drop the registry entry and the pointer, even if Netica
            # reported an error, so a later call cannot double-free.
            if envrn.dict_initialization:
                envrn.cptr_dict.pop(cptr, None)
            self.cptr = None

    @staticmethod
    def _as_stream(file):
        """Return 'file' as a Stream, creating one when 'file' is a filename.

        Raises TypeError if 'file' is neither a Stream nor a str.
        """
        if isinstance(file, strm.Stream):
            return file
        if isinstance(file, str):
            return strm.Stream(file)
        raise TypeError('A Stream or filename is required (got type {})'.format(type(file).__name__))

    def add_cases_from_file(self, file, degree=1.0, options=None):
        """Add all the cases contained in 'file'.

        'file' is a Stream, or a filename (str) from which a Stream is made.
        'degree' is a multiplier for the multiplicity of each case.
        'options' must be empty or None.
        Like Netica-C AddFileToCaseset_cs.

        Raises TypeError if 'file' is neither a Stream nor a str.
        """
        if options is not None:
            options = ctypes.c_char_p(options.encode())

        stream = self._as_stream(file)

        Netica.AddFileToCaseset_cs.restype = None
        Netica.AddFileToCaseset_cs(ctypes.c_void_p(self.cptr), ctypes.c_void_p(stream.cptr),
                                   ctypes.c_double(degree), options)
        err.checkerr()
        # Keep the stream alive with this caseset (a stream created here from
        # a filename would otherwise be freed as soon as this method returns).
        self._streams.append(stream)

    def add_cases_from_database(self, dbmgr, degree, nodes, column_names, tables, condition, options=None):
        """Add to this Caseset the cases contained in the database of 'dbmgr'.

        'dbmgr' is a DatabaseManager and 'nodes' a NodeList.
        'degree' is a multiplier for the multiplicity of each case.
        'column_names', 'tables', 'condition' and 'options' are strings or
        None (None is passed to Netica as a NULL pointer).
        NOTE(review): presumably 'column_names' gives the database column
        for each node and 'condition' a SQL filter; confirm with Netica docs.
        Like Netica-C AddDBCasesToCaseset_cs.

        Raises TypeError if 'dbmgr' or 'nodes' has the wrong type.
        """
        if not isinstance(dbmgr, DatabaseManager):
            raise TypeError('A DatabaseManager is required (got type {})'.format(type(dbmgr).__name__))
        if not isinstance(nodes, ndlst.NodeList):
            raise TypeError('A NodeList is required (got type {})'.format(type(nodes).__name__))

        if column_names is not None:
            column_names = ctypes.c_char_p(column_names.encode())
        if tables is not None:
            tables = ctypes.c_char_p(tables.encode())
        if condition is not None:
            condition = ctypes.c_char_p(condition.encode())
        if options is not None:
            options = ctypes.c_char_p(options.encode())

        Netica.AddDBCasesToCaseset_cs.restype = None
        Netica.AddDBCasesToCaseset_cs(ctypes.c_void_p(self.cptr), ctypes.c_void_p(dbmgr.cptr),
                                      ctypes.c_double(degree), ctypes.c_void_p(nodes.cptr),
                                      column_names, tables, condition, options)
        err.checkerr()

    def write(self, file, options=None):
        """Write all the cases to the indicated file (which may be on disk or in memory).

        'file' is a Stream, or a filename (str) from which a Stream is made.
        'options' must be empty or None.
        Like Netica-C WriteCaseset_cs.

        Raises TypeError if 'file' is neither a Stream nor a str.
        """
        if options is not None:
            options = ctypes.c_char_p(options.encode())

        stream = self._as_stream(file)

        Netica.WriteCaseset_cs.restype = None
        Netica.WriteCaseset_cs(ctypes.c_void_p(self.cptr),
                               ctypes.c_void_p(stream.cptr), options)
        err.checkerr()

"""-----------------------------DatabaseManager Class-----------------------------"""

class DatabaseManager:
    """A database connection, wrapping the Netica-C object made by NewDBManager_cs.

    Used to add database cases to a Caseset and to insert findings into the
    database.  The Netica object is freed (DeleteDBManager_cs) when this
    Python object is garbage collected.
    """

    def __init__(self, connect_str, options=None, environ=None):
        """Create a new database manager, and have it connect with a database.

        'connect_str' is the connection string handed to Netica.
        'options' may be None.
        'environ' is the Environ to use; None means the default environment.
        Like Netica-C NewDBManager_cs.

        Raises TypeError if 'environ' is neither None nor an Environ.
        """
        # Set before anything can fail, so __del__ can tell that construction
        # never completed and there is nothing to free.
        self.cptr = None

        if environ is None:
            environ_cptr = envrn.env
        elif isinstance(environ, envrn.Environ):
            environ_cptr = environ.cptr
        else:
            raise TypeError('An Environ or None is required (got type {})'.format(type(environ).__name__))

        if options is not None:
            options = ctypes.c_char_p(options.encode())
        Netica.NewDBManager_cs.restype = ctypes.c_void_p
        cptr = Netica.NewDBManager_cs(ctypes.c_char_p(connect_str.encode()),
                                      options, ctypes.c_void_p(environ_cptr))
        err.checkerr()

        self.cptr = cptr

        if envrn.dict_initialization:
            envrn.cptr_dict[cptr] = weakref.ref(self)

    def __del__(self):
        """Remove this DatabaseManager.

        Closes connections and frees all the resources it consumes, including
        memory.  Safe to call on an object whose construction failed, and safe
        to call more than once.
        Like Netica-C DeleteDBManager_cs.
        """
        cptr = getattr(self, 'cptr', None)
        if cptr is None:
            # Construction failed or already deleted: nothing to free.
            return
        try:
            if envrn.env is not None:
                Netica.DeleteDBManager_cs.restype = None
                Netica.DeleteDBManager_cs(ctypes.c_void_p(cptr))
                err.checkerr()
        finally:
            # Drop registry entry and pointer even if Netica reported an error.
            if envrn.dict_initialization:
                envrn.cptr_dict.pop(cptr, None)
            self.cptr = None

    def execute_sql(self, sql_cmd, options=None):
        """Execute the passed SQL statement 'sql_cmd'.

        'options' may be None.
        Like Netica-C ExecuteDBSql_cs.
        """
        if options is not None:
            options = ctypes.c_char_p(options.encode())
        Netica.ExecuteDBSql_cs.restype = None
        Netica.ExecuteDBSql_cs(ctypes.c_void_p(self.cptr),
                               ctypes.c_char_p(sql_cmd.encode()), options)
        err.checkerr()

    def insert_findings(self, nodes, column_names, tables, options=None):
        """Add a new record to the database, consisting of the findings in 'nodes'.

        If 'column_names' is non-empty, it contains the database names for
        each of the variables in 'nodes'.
        'tables' empty means use default table.
        None is accepted for 'column_names' and 'tables' and is treated as
        the empty string.
        Like Netica-C InsertFindingsIntoDB_bn.

        Raises TypeError if 'nodes' is not a NodeList.
        """
        if options is not None:
            options = ctypes.c_char_p(options.encode())

        if not isinstance(nodes, ndlst.NodeList):
            raise TypeError('A NodeList is required (got type {})'.format(type(nodes).__name__))

        # The documented "empty" meaning is requested with ''; map None to it
        # instead of failing on None.encode().
        if column_names is None:
            column_names = ''
        if tables is None:
            tables = ''

        Netica.InsertFindingsIntoDB_bn.restype = None
        Netica.InsertFindingsIntoDB_bn(ctypes.c_void_p(self.cptr), ctypes.c_void_p(nodes.cptr),
                                       ctypes.c_char_p(column_names.encode()),
                                       ctypes.c_char_p(tables.encode()), options)
        err.checkerr()

"""------------------------------Learner Class------------------------------"""  
      
class Learner:
    """A learning algorithm, wrapping the Netica-C object made by NewLearner_bn.

    The Netica object is freed (DeleteLearner_bn) when this Python object is
    garbage collected.
    """

    def __init__(self, method, options=None, environ=None):
        """Create a new learner, which will learn using 'method'.

        Method can be one of "COUNTING", "EM", or "GRADIENT_DESCENT".
        'options' is a string passed through to Netica, or None.
        'environ' is the Environ to use; None means the default environment.
        Like Netica-C NewLearner_bn.

        Raises TypeError if 'environ' is neither None nor an Environ.
        """
        # Set before anything can fail, so __del__ can tell that construction
        # never completed and there is nothing to free.
        self.cptr = None

        if environ is None:
            environ_cptr = envrn.env
        elif isinstance(environ, envrn.Environ):
            environ_cptr = environ.cptr
        else:
            raise TypeError('An Environ or None is required (got type {})'.format(type(environ).__name__))

        # Previously 'options' was accepted but never forwarded to Netica.
        if options is not None:
            options = ctypes.c_char_p(options.encode())

        Netica.NewLearner_bn.restype = ctypes.c_void_p
        cptr = Netica.NewLearner_bn(ctypes.c_int(enums.set_learning_method(method)), options,
                                    ctypes.c_void_p(environ_cptr))
        err.checkerr()

        self.cptr = cptr

        if envrn.dict_initialization:
            envrn.cptr_dict[cptr] = weakref.ref(self)

    def __del__(self):
        """Remove this Learner, freeing all resources it consumes, including memory.

        Safe to call on an object whose construction failed, and safe to call
        more than once.
        Like Netica-C DeleteLearner_bn.
        """
        cptr = getattr(self, 'cptr', None)
        if cptr is None:
            # Construction failed or already deleted: nothing to free.
            return
        try:
            if envrn.env is not None:
                Netica.DeleteLearner_bn.restype = None
                Netica.DeleteLearner_bn(ctypes.c_void_p(cptr))
                err.checkerr()
        finally:
            # Drop registry entry and pointer even if Netica reported an error.
            if envrn.dict_initialization:
                envrn.cptr_dict.pop(cptr, None)
            self.cptr = None

    @property
    def max_iterations(self):
        """Parameter to control termination.

        The maximum number of learning step iterations.
        Like Netica-C SetLearnerMaxIters_bn.
        """
        Netica.SetLearnerMaxIters_bn.restype = ctypes.c_int
        # -1 is passed as a "query only" value so the current setting is
        # returned.  NOTE(review): confirm against Netica's QUERY_ns value.
        maxiters = Netica.SetLearnerMaxIters_bn(ctypes.c_void_p(self.cptr), ctypes.c_int(-1))
        err.checkerr()
        return maxiters

    @max_iterations.setter
    def max_iterations(self, max_iters):
        """Parameter to control termination.

        The maximum number of learning step iterations.
        Like Netica-C SetLearnerMaxIters_bn.
        """
        Netica.SetLearnerMaxIters_bn.restype = ctypes.c_int
        Netica.SetLearnerMaxIters_bn(ctypes.c_void_p(self.cptr), ctypes.c_int(max_iters))
        err.checkerr()

    @property
    def max_tolerance(self):
        """Parameter to control termination.

        When the log likelihood of the data given the model improves by less
        than this, the model is considered close enough.
        Like Netica-C SetLearnerMaxTol_bn.
        """
        Netica.SetLearnerMaxTol_bn.restype = ctypes.c_double
        # enums.QUERY_ns asks Netica for the current value without changing it.
        maxtol = Netica.SetLearnerMaxTol_bn(ctypes.c_void_p(self.cptr), ctypes.c_double(enums.QUERY_ns))
        err.checkerr()
        return maxtol

    @max_tolerance.setter
    def max_tolerance(self, max_tol):
        """Parameter to control termination.

        When the log likelihood of the data given the model improves by less
        than this, the model is considered close enough.
        Like Netica-C SetLearnerMaxTol_bn.
        """
        Netica.SetLearnerMaxTol_bn.restype = ctypes.c_double
        Netica.SetLearnerMaxTol_bn(ctypes.c_void_p(self.cptr), ctypes.c_double(max_tol))
        err.checkerr()

    def learn_CPTs(self, nodes, cases, degree):
        """Modify the CPTs (and experience tables) for the nodes in 'nodes',
        to take into account the case data from 'cases' (with multiplicity
        multiplier 'degree', which is normally 1).

        Like Netica-C LearnCPTs_bn.

        Raises TypeError if 'nodes' is not a NodeList or 'cases' not a Caseset.
        """
        if not isinstance(nodes, ndlst.NodeList):
            raise TypeError('A NodeList is required (got type {})'.format(type(nodes).__name__))
        if not isinstance(cases, Caseset):
            raise TypeError('A Caseset is required (got type {})'.format(type(cases).__name__))

        Netica.LearnCPTs_bn.restype = None
        Netica.LearnCPTs_bn(ctypes.c_void_p(self.cptr), ctypes.c_void_p(nodes.cptr),
                            ctypes.c_void_p(cases.cptr), ctypes.c_double(degree))
        err.checkerr()

"""------------------------------Tester Class-------------------------------"""  
    
class Tester:
    """A net performance tester, wrapping the Netica-C object made by NewNetTester_bn.

    The Netica object is freed (DeleteNetTester_bn) when this Python object is
    garbage collected.
    """

    def __init__(self, test_nodes, unobsv_nodes, tests=-1):
        """Create a Tester to performance test this net using a set of cases.

        'test_nodes' is a NodeList; 'unobsv_nodes' is a NodeList or None.
        'tests' is passed through to Netica as an int.
        Like Netica-C NewNetTester_bn.

        Raises TypeError if a node list argument has the wrong type.
        """
        # Set before anything can fail, so __del__ can tell that construction
        # never completed and there is nothing to free.
        self.cptr = None

        if not isinstance(test_nodes, ndlst.NodeList):
            raise TypeError('A NodeList is required (got type {})'.format(type(test_nodes).__name__))
        if unobsv_nodes is not None:
            if not isinstance(unobsv_nodes, ndlst.NodeList):
                raise TypeError('A NodeList is required (got type {})'.format(type(unobsv_nodes).__name__))
            unobsv_nodes = ctypes.c_void_p(unobsv_nodes.cptr)

        Netica.NewNetTester_bn.restype = ctypes.c_void_p
        cptr = Netica.NewNetTester_bn(ctypes.c_void_p(test_nodes.cptr),
                                      unobsv_nodes, ctypes.c_int(tests))
        err.checkerr()

        self.cptr = cptr

        if envrn.dict_initialization:
            envrn.cptr_dict[cptr] = weakref.ref(self)

    def __del__(self):
        """Remove this net tester, freeing all resources it consumes, including memory.

        Safe to call on an object whose construction failed, and safe to call
        more than once.
        Like Netica-C DeleteNetTester_bn.
        """
        cptr = getattr(self, 'cptr', None)
        if cptr is None:
            # Construction failed or already deleted: nothing to free.
            return
        try:
            if envrn.env is not None:
                Netica.DeleteNetTester_bn.restype = None
                Netica.DeleteNetTester_bn(ctypes.c_void_p(cptr))
                err.checkerr()
        finally:
            # Drop registry entry and pointer even if Netica reported an error.
            if envrn.dict_initialization:
                envrn.cptr_dict.pop(cptr, None)
            self.cptr = None

    def test_with_cases(self, cases):
        """Scan through the data in 'cases', and for each case check the values
        in the case for test_nodes against the predictions made by the net
        based on the other values in the case.

        test_nodes is the NodeList given when this Tester was constructed.
        Like Netica-C TestWithCaseset_bn.

        Raises TypeError if 'cases' is not a Caseset.
        """
        if not isinstance(cases, Caseset):
            raise TypeError('A Caseset is required (got type {})'.format(type(cases).__name__))
        Netica.TestWithCaseset_bn.restype = None
        Netica.TestWithCaseset_bn(ctypes.c_void_p(self.cptr), ctypes.c_void_p(cases.cptr))
        err.checkerr()

    def error_rate(self, node):
        """Return the fraction of test cases where the net predicted the wrong state.

        Like Netica-C GetTestErrorRate_bn.
        """
        if not isinstance(node, nd.Node):
            raise TypeError('A Node is required (got type {})'.format(type(node).__name__))
        Netica.GetTestErrorRate_bn.restype = ctypes.c_double
        error_rate = Netica.GetTestErrorRate_bn(ctypes.c_void_p(self.cptr),
                                                ctypes.c_void_p(node.cptr))
        err.checkerr()
        return error_rate

    def log_loss(self, node):
        """The 'logarithmic loss', which for each case takes into account the
        prediction probability the net gives to the state that turns out to be correct.

        Ranges from 0 (perfect score) to infinity.
        Like Netica-C GetTestLogLoss_bn.
        """
        if not isinstance(node, nd.Node):
            raise TypeError('A Node is required (got type {})'.format(type(node).__name__))
        Netica.GetTestLogLoss_bn.restype = ctypes.c_double
        log_loss = Netica.GetTestLogLoss_bn(ctypes.c_void_p(self.cptr),
                                            ctypes.c_void_p(node.cptr))
        err.checkerr()
        return log_loss

    def quadratic_loss(self, node):
        """The 'quadratic loss', also known as 'Brier score' for 'node' under
        the test performed by test_with_cases.

        Like Netica-C GetTestQuadraticLoss_bn.
        """
        if not isinstance(node, nd.Node):
            raise TypeError('A Node is required (got type {})'.format(type(node).__name__))
        Netica.GetTestQuadraticLoss_bn.restype = ctypes.c_double
        quadratic_loss = Netica.GetTestQuadraticLoss_bn(ctypes.c_void_p(self.cptr),
                                                        ctypes.c_void_p(node.cptr))
        err.checkerr()
        return quadratic_loss

    def get_confusion_matrix(self, node, predicted, actual):
        """Return an element of the 'confusion matrix'.

        Element is the number of times the net predicted 'predicted' for
        node, but the case file actually held 'actual' as the value of
        that node.  States may be given as indexes or as state names.
        Like Netica-C GetTestConfusion_bn.
        """
        if not isinstance(node, nd.Node):
            raise TypeError('A Node is required (got type {})'.format(type(node).__name__))
        if isinstance(predicted, str):
            predicted = node.get_state_named(predicted)
        if isinstance(actual, str):
            actual = node.get_state_named(actual)
        Netica.GetTestConfusion_bn.restype = ctypes.c_double
        test_confusion = Netica.GetTestConfusion_bn(ctypes.c_void_p(self.cptr),
                                                    ctypes.c_void_p(node.cptr),
                                                    ctypes.c_int(predicted),
                                                    ctypes.c_int(actual))
        err.checkerr()
        return test_confusion

    def binary_score(self, score, node, positive_state, granularity=-1):
        """***nodocs

        Wraps Netica-C GetTestBinaryScore_bn.
        'positive_state' may be a state index or a state name.
        Returns (first value of the array Netica returns, list of thresholds).
        """
        if not isinstance(node, nd.Node):
            raise TypeError('A Node is required (got type {})'.format(type(node).__name__))
        # Accept state names, as get_confusion_matrix does.
        if isinstance(positive_state, str):
            positive_state = node.get_state_named(positive_state)
        cnum_entries = ctypes.c_int(0)
        thresholds_ref = ctypes.POINTER(ctypes.c_double)()
        Netica.GetTestBinaryScore_bn.restype = ctypes.POINTER(ctypes.c_double)
        score_pointer = Netica.GetTestBinaryScore_bn(ctypes.c_void_p(self.cptr),
                                                     ctypes.c_char_p(score.encode()),
                                                     ctypes.c_void_p(node.cptr),
                                                     ctypes.c_int(positive_state),
                                                     ctypes.c_double(granularity),
                                                     ctypes.byref(cnum_entries),
                                                     ctypes.byref(thresholds_ref))
        err.checkerr()

        binary_score = score_pointer[0]
        num_entries = cnum_entries.value
        # Slicing a ctypes pointer copies the values into a Python list; skip
        # it entirely when there is nothing to read.
        thresholds = thresholds_ref[:num_entries] if num_entries > 0 else []

        return binary_score, thresholds
        
        

"""--------------------------RandomGenerator Class--------------------------"""

class RandomGenerator:
    """A pseudo-random generator, wrapping the Netica-C object made by NewRandomGenerator_ns.

    The Netica object is freed (DeleteRandomGen_ns) when this Python object
    is garbage collected.
    """

    def __init__(self, seed, options=None, environ=None):
        """Create a new RandomGenerator.

        For seed, pass a positive (or zero) integer, either as an int or in
        the form of a string (it is converted with str() before being passed
        to Netica).
        options should either be None or the string "Nondeterministic".
        'environ' is the Environ to use; None means the default environment.
        Like Netica-C NewRandomGenerator_ns.

        Raises TypeError if 'environ' is neither None nor an Environ.
        """
        # Set before anything can fail, so __del__ can tell that construction
        # never completed and there is nothing to free.
        self.cptr = None

        if environ is None:
            environ_cptr = envrn.env
        elif isinstance(environ, envrn.Environ):
            environ_cptr = environ.cptr
        else:
            raise TypeError('An Environ or None is required (got type {})'.format(type(environ).__name__))

        if options is not None:
            options = ctypes.c_char_p(options.encode())
        Netica.NewRandomGenerator_ns.restype = ctypes.c_void_p
        cptr = Netica.NewRandomGenerator_ns(ctypes.c_char_p(str(seed).encode()),
                                            ctypes.c_void_p(environ_cptr), options)
        err.checkerr()

        self.cptr = cptr

        if envrn.dict_initialization:
            envrn.cptr_dict[cptr] = weakref.ref(self)

    def __del__(self):
        """Remove this random generator, freeing all resources it consumes, including memory.

        Safe to call on an object whose construction failed, and safe to call
        more than once.
        Like Netica-C DeleteRandomGen_ns.
        """
        cptr = getattr(self, 'cptr', None)
        if cptr is None:
            # Construction failed or already deleted: nothing to free.
            return
        try:
            if envrn.env is not None:
                Netica.DeleteRandomGen_ns.restype = None
                Netica.DeleteRandomGen_ns(ctypes.c_void_p(cptr))
                err.checkerr()
        finally:
            # Drop registry entry and pointer even if Netica reported an error.
            if envrn.dict_initialization:
                envrn.cptr_dict.pop(cptr, None)
            self.cptr = None

    def get_state(self, options=None):
        """***nodocs

        Return the generator state string reported by Netica, or None if
        Netica returns a NULL string.
        Like Netica-C GetRandomGenState_ns
        """
        if options is not None:
            options = ctypes.c_char_p(options.encode())
        Netica.GetRandomGenState_ns.restype = ctypes.c_char_p
        state = Netica.GetRandomGenState_ns(ctypes.c_void_p(self.cptr), options)
        err.checkerr()
        # A NULL char* comes back from ctypes as None.
        return state.decode() if state is not None else None

    def generate_random_numbers(self, num, options=None):
        """Generate num pseudo-random numbers using this generator and return them in a list.

        The numbers will be between 0 (inclusive) and 1 (exclusive), that is,
        from the interval [0,1).
        Like Netica-C GenerateRandomNumbers_ns
        """
        if options is not None:
            options = ctypes.c_char_p(options.encode())
        # ctypes arrays are zero-initialized; Netica fills this buffer in place.
        results = (ctypes.c_double * num)()
        Netica.GenerateRandomNumbers_ns.restype = ctypes.c_double
        Netica.GenerateRandomNumbers_ns(ctypes.c_void_p(self.cptr), results,
                                        ctypes.c_int(num), options)
        err.checkerr()
        return list(results)

    
"""----------------------------Sensitivity Class----------------------------"""

class Sensitivity:
    """A sensitivity measurer, wrapping the Netica-C object made by NewSensvToFinding_bn.

    The Netica object is freed (DeleteSensvToFinding_bn) when this Python
    object is garbage collected.
    """

    def __init__(self, t_node, findings_nodes, what_calc):
        """Create a sensitivity measurer to determine how much 't_node' could
        be affected by new findings at certain other nodes ('findings_nodes').

        'what_calc' selects the measure; it is converted by enums.set_sensv.
        Like Netica-C NewSensvToFinding_bn.

        Raises TypeError if 't_node' is not a Node or 'findings_nodes' is not
        a NodeList.
        """
        # Set before anything can fail, so __del__ can tell that construction
        # never completed and there is nothing to free.
        self.cptr = None

        if not isinstance(t_node, nd.Node):
            raise TypeError('A Node is required (got type {})'.format(type(t_node).__name__))
        if not isinstance(findings_nodes, ndlst.NodeList):
            raise TypeError('A NodeList is required (got type {})'.format(type(findings_nodes).__name__))
        Netica.NewSensvToFinding_bn.restype = ctypes.c_void_p
        cptr = Netica.NewSensvToFinding_bn(ctypes.c_void_p(t_node.cptr),
                                           ctypes.c_void_p(findings_nodes.cptr),
                                           ctypes.c_int(enums.set_sensv(what_calc)))
        err.checkerr()

        self.cptr = cptr

        if envrn.dict_initialization:
            envrn.cptr_dict[cptr] = weakref.ref(self)

        # The target node this measurer was created for.
        self.node = t_node

    def __del__(self):
        """Remove this sensitivity measurer, freeing all resources it consumes, including memory.

        Safe to call on an object whose construction failed, and safe to call
        more than once.
        Like Netica-C DeleteSensvToFinding_bn.
        """
        cptr = getattr(self, 'cptr', None)
        if cptr is None:
            # Construction failed or already deleted: nothing to free.
            return
        try:
            if envrn.env is not None:
                Netica.DeleteSensvToFinding_bn.restype = None
                Netica.DeleteSensvToFinding_bn(ctypes.c_void_p(cptr))
                err.checkerr()
        finally:
            # Drop registry entry and pointer even if Netica reported an error.
            if envrn.dict_initialization:
                envrn.cptr_dict.pop(cptr, None)
            self.cptr = None

    def get_mutual_info(self, finding_node):
        """Return the mutual information between the target node and finding_node.

        I.e. expected reduction in entropy of the target node due to a finding
        at finding_node.  Create this Sensitivity object with an entropy
        measure, e.g. q_node.new_sensitivity(EntropyMeasure, ..).
        Like Netica-C GetSensvMutualInfo_bn.
        """
        if not isinstance(finding_node, nd.Node):
            raise TypeError('A Node is required (got type {})'.format(type(finding_node).__name__))
        Netica.GetSensvMutualInfo_bn.restype = ctypes.c_double
        mutual_info = Netica.GetSensvMutualInfo_bn(ctypes.c_void_p(self.cptr),
                                                   ctypes.c_void_p(finding_node.cptr))
        err.checkerr()
        return mutual_info

    def get_variance_of_real(self, finding_node):
        """The expected change squared in the expected real value of the target
        node, if a finding was obtained for finding_node.

        Create this Sensitivity object with a real-value variance measure,
        e.g. query_node.NewSensitivity(RealMeasure+VarianceMeasure, ..).
        Like Netica-C GetVarianceOfReal_bn.
        """
        if not isinstance(finding_node, nd.Node):
            raise TypeError('A Node is required (got type {})'.format(type(finding_node).__name__))
        Netica.GetVarianceOfReal_bn.restype = ctypes.c_double
        variance_of_real = Netica.GetVarianceOfReal_bn(ctypes.c_void_p(self.cptr),
                                                       ctypes.c_void_p(finding_node.cptr))
        err.checkerr()
        return variance_of_real