# -*- coding: utf-8 -*-
"""
Created on Sun Mar 15 17:24:04 2020

@author: Sophie
"""

import ctypes

from neticapy import enums
from neticapy import net as nt
from neticapy import neticaerror as err
from neticapy import rawerror
from neticapy.loaddll import Netica

# Initialize methods for keeping track of object uniqueness
# Some features may not work with dict_initialization = False (multiple python
# objects may be created for each Netica object)
dict_initialization = True
# Can be used in place of cptr_dict for Nodes and Nets
userdata_initialization = False

# Initialize dict that will contain the c pointers for all objects.
# Search when initializing a new object to return an existing python object.
# Values are called to obtain the wrapper (``cptr_dict[cptr]()``), i.e. they
# are used like weak references to the Python objects.
if dict_initialization:
    cptr_dict = {}

# Event code passed to the net/node listeners when the object is deleted.
_DELETION_EVENT = 4

# Initialize callbacks that will be used for responding to deletion events.
# Signature: int callback(void *obj, int event, void *obj2, void *data)
CALLBACK = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_void_p, ctypes.c_int,
                            ctypes.c_void_p, ctypes.c_void_p)


def _forget_cptr(cptr):
    """Forget the Python wrapper registered for the C pointer ``cptr``.

    Removes the ``cptr_dict`` entry even when the wrapper has already been
    garbage collected, so stale entries do not accumulate and cannot be
    matched if Netica later reuses the same address. If the wrapper is
    still alive, its ``cptr`` is set to None so it can no longer be used to
    reach freed Netica memory.
    """
    if not dict_initialization:
        return
    ref = cptr_dict.pop(cptr, None)
    if ref is None:
        return
    py_obj = ref()
    # Compare against None rather than relying on truthiness: a wrapper that
    # defines __len__ (e.g. an empty container-like object) would be falsy.
    if py_obj is not None:
        py_obj.cptr = None


@CALLBACK
def _net_callback(net, event, obj, data):
    """Netica net listener: invalidate the Python net when Netica deletes it.

    Always returns 0 to Netica.
    """
    if event == _DELETION_EVENT:
        _forget_cptr(net)
    return 0


@CALLBACK
def _node_callback(node, event, obj, data):
    """Netica node listener: invalidate the Python node when Netica deletes it.

    Always returns 0 to Netica.
    """
    if event == _DELETION_EVENT:
        _forget_cptr(node)
    return 0

class Environ:
    """Python wrapper for a Netica environment (Netica-C environ_ns).

    The C pointer to the environment is kept in ``self.cptr`` and is also
    stored in the module-level global ``env``.
    """

    def __init__(self, license_=None, environ=None, locn=None, _init=True):
        """Create a Netica environment and, by default, initialize Netica.

        Like Netica-C NewNeticaEnviron_ns (followed by InitNetica2_bn etc.
        when ``_init`` is True).

        license_ : str or None, encoded and passed as the license string.
        environ : passed through unchanged to NewNeticaEnviron_ns.
        locn : str or None, encoded and passed to NewNeticaEnviron_ns.
        _init : when True (the default), call ``_init_netica``; this stores
            Netica's start-up message in ``self.mesg``.

        Raises Exception with Netica's message if InitNetica2_bn returns a
        negative value.
        """
        global env

        # Initialize cptr for case where Netica raises an error during construction
        self.cptr = None

        if license_ is not None:
            license_ = ctypes.c_char_p(license_.encode())
        if locn is not None:
            locn = ctypes.c_char_p(locn.encode())

        Netica.NewNeticaEnviron_ns.restype = ctypes.c_void_p
        env = Netica.NewNeticaEnviron_ns(license_, environ, locn)
        self.cptr = env

        if _init is True:
            self._init_netica()

    # Someday, may want to allow user to initialize outside of __init__
    # (e.g. by making this helper public).
    def _init_netica(self):
        """Initialize Netica for this environment.

        Calls the Netica-C functions InitNetica2_bn, SetLanguage_ns,
        AddNetListener_bn, AddNodeListener_bn, GetUndefDbl_ns and
        GetInfinityDbl_ns. Stores Netica's message in ``self.mesg`` and sets
        the module globals UNDEF_DBL and INFINITY.

        Returns (res, mesg) as produced by InitNetica2_bn.
        Raises Exception(mesg) if InitNetica2_bn returns a negative value.
        """
        global UNDEF_DBL, INFINITY

        # Netica writes its message into this buffer, so it must be mutable
        # memory. A c_char_p built from a bytes object points at immutable
        # Python memory and must not be written by C code.
        mesg_buf = ctypes.create_string_buffer(enums._MESG_LEN_ns)

        Netica.InitNetica2_bn.restype = ctypes.c_int
        res = Netica.InitNetica2_bn(ctypes.c_void_p(self.cptr), mesg_buf)
        mesg = mesg_buf.value.decode()
        err.checkerr()

        # A negative result means initialization failed
        if res < 0:
            raise Exception(mesg)

        self.mesg = mesg

        Netica.SetLanguage_ns.restype = ctypes.c_char_p
        Netica.SetLanguage_ns(ctypes.c_void_p(self.cptr),
                              ctypes.c_char_p("Python".encode()))
        err.checkerr()

        # Set up global listeners so the module-level callbacks can
        # invalidate Python wrappers when Netica deletes nets/nodes.
        Netica.AddNetListener_bn.restype = None
        Netica.AddNetListener_bn(None, _net_callback, None, -1)

        Netica.AddNodeListener_bn.restype = None
        Netica.AddNodeListener_bn(None, _node_callback, None, -1)

        Netica.GetUndefDbl_ns.restype = ctypes.c_double
        UNDEF_DBL = Netica.GetUndefDbl_ns()

        # Must be ``restype``: without it ctypes decodes the result as a C int.
        Netica.GetInfinityDbl_ns.restype = ctypes.c_double
        INFINITY = Netica.GetInfinityDbl_ns()

        err.checkerr()

        return res, mesg

    def close_netica(self):
        """Exit Netica (i.e. make it quit).

        If a human user has entered unsaved changes, this will check with the
        user first. Clears the module global ``env``.
        Like Netica-C CloseNetica_bn.

        Returns (res, mesg) as produced by CloseNetica_bn.
        """
        global env

        # Mutable buffer for Netica to write its closing message into.
        mesg_buf = ctypes.create_string_buffer(enums._MESG_LEN_ns)

        Netica.CloseNetica_bn.restype = ctypes.c_int
        res = Netica.CloseNetica_bn(ctypes.c_void_p(self.cptr), mesg_buf)
        mesg = mesg_buf.value.decode()
        env = None
        return res, mesg

    @property
    def version_num(self):
        """Return version number of Netica.

        Version is multiplied by 100, so version 3.24 would return 324.
        Like Netica-C GetNeticaVersion_bn.
        """
        Netica.GetNeticaVersion_bn.restype = ctypes.c_int
        version_num = Netica.GetNeticaVersion_bn(ctypes.c_void_p(self.cptr), None)
        err.checkerr()
        return version_num

    @property
    def version(self):
        """Return Netica version information.

        This consists of the full version number, a space, a code for the type
        of machine it is running on, a comma, the name of the program, and
        finally a code indicating some build information.
        Like Netica-C GetNeticaVersion_bn.
        """
        # Netica stores a pointer to its version string in this variable.
        version_str = ctypes.c_char_p(b'')
        Netica.GetNeticaVersion_bn.restype = ctypes.c_int
        Netica.GetNeticaVersion_bn(ctypes.c_void_p(self.cptr),
                                   ctypes.byref(version_str))
        err.checkerr()
        return version_str.value.decode()

    @property
    def argument_checking(self):
        """To what degree Netica functions check their arguments.

        Like Netica-C ArgumentChecking_ns.
        """
        Netica.ArgumentChecking_ns.restype = ctypes.c_int
        setting = Netica.ArgumentChecking_ns(ctypes.c_int(enums.QUERY_ns),
                                             ctypes.c_void_p(self.cptr))
        err.checkerr()
        return enums.get_argument_checking(setting)

    @argument_checking.setter
    def argument_checking(self, setting):
        """Set to what degree Netica functions check their arguments.

        Like Netica-C ArgumentChecking_ns.
        """
        Netica.ArgumentChecking_ns.restype = ctypes.c_int
        Netica.ArgumentChecking_ns(ctypes.c_int(enums.set_argument_checking(setting)),
                                   ctypes.c_void_p(self.cptr))
        err.checkerr()

    def set_password(self, password, options=None):
        """Set the Netica user password for this session only.

        Like the password passed to Netica-C NewNeticaEnviron_ns.
        Like Netica-C SetPassword_ns.

        password : str
        options : str or None
        """
        if options is not None:
            options = ctypes.c_char_p(options.encode())
        Netica.SetPassword_ns.restype = None
        Netica.SetPassword_ns(ctypes.c_void_p(self.cptr),
                              ctypes.c_char_p(password.encode()), options)
        err.checkerr()

    @property
    def memory_usage_limit(self):
        """The maximum amount of memory Netica is allowed to use (in bytes).

        This includes all tables and lists.
        When setting, max_mem is the number of bytes allowed. For example, to
        limit memory usage to 100 Megabytes, pass 100e6.
        Like Netica-C LimitMemoryUsage_ns.
        """
        Netica.LimitMemoryUsage_ns.restype = ctypes.c_double
        mem_usage_limit = Netica.LimitMemoryUsage_ns(ctypes.c_double(enums.QUERY_ns),
                                                     ctypes.c_void_p(self.cptr))
        err.checkerr()
        return mem_usage_limit

    @memory_usage_limit.setter
    def memory_usage_limit(self, max_mem):
        """Set the maximum amount of memory Netica is allowed to use (in bytes).

        Like Netica-C LimitMemoryUsage_ns.
        """
        Netica.LimitMemoryUsage_ns.restype = ctypes.c_double
        Netica.LimitMemoryUsage_ns(ctypes.c_double(max_mem), ctypes.c_void_p(self.cptr))
        err.checkerr()

    def test_fault_recovery(self, test_num):
        """Test Netica's fault recovery under extreme conditions.

        Like Netica-C TestFaultRecovery_ns. Returns Netica's int result.
        """
        Netica.TestFaultRecovery_ns.restype = ctypes.c_int
        ret = Netica.TestFaultRecovery_ns(ctypes.c_void_p(self.cptr), ctypes.c_int(test_num))
        err.checkerr()
        return ret

    @property
    def case_file_delimiter(self):
        """The character to use as a delimiter when creating case files.

        Returned as a one-character string. When setting, pass either the
        ascii character code or a one-character string.
        It must be one of tab (9), space (32) or comma (44).
        Like Netica-C SetCaseFileDelimChar_ns.
        """
        Netica.SetCaseFileDelimChar_ns.restype = ctypes.c_int
        delimiter = Netica.SetCaseFileDelimChar_ns(ctypes.c_int(enums.QUERY_ns),
                                                   ctypes.c_void_p(self.cptr))
        err.checkerr()
        return chr(delimiter)

    @case_file_delimiter.setter
    def case_file_delimiter(self, newchar):
        """Set the case-file delimiter (ascii code or one-character string).

        Like Netica-C SetCaseFileDelimChar_ns.
        """
        if isinstance(newchar, str):
            newchar = ord(newchar)
        Netica.SetCaseFileDelimChar_ns.restype = ctypes.c_int
        Netica.SetCaseFileDelimChar_ns(ctypes.c_int(newchar), ctypes.c_void_p(self.cptr))
        err.checkerr()

    @property
    def case_file_missing_data_char(self):
        """The character to use to indicate missing data when creating case files.

        Returned as a one-character string. When setting, pass either the
        ascii character code or a one-character string.
        It must be one of asterisk * (42), question mark ? (63), space (32) or
        absent (0).
        Like Netica-C SetMissingDataChar_ns.
        """
        Netica.SetMissingDataChar_ns.restype = ctypes.c_int
        # Query with QUERY_ns, consistent with the other getters of this class
        missing_char = Netica.SetMissingDataChar_ns(ctypes.c_int(enums.QUERY_ns),
                                                    ctypes.c_void_p(self.cptr))
        err.checkerr()
        return chr(missing_char)

    @case_file_missing_data_char.setter
    def case_file_missing_data_char(self, newchar):
        """Set the missing-data character (ascii code or one-character string).

        Like Netica-C SetMissingDataChar_ns.
        """
        if isinstance(newchar, str):
            newchar = ord(newchar)
        Netica.SetMissingDataChar_ns.restype = ctypes.c_int
        Netica.SetMissingDataChar_ns(ctypes.c_int(newchar), ctypes.c_void_p(self.cptr))
        err.checkerr()

    def get_net(self, name):
        """Return a BNet that Netica currently has in memory.

        Can call function with either net name or integer index. If there are
        multiple nets with the same name, return first net of that name. When
        called with an integer, each value returns a different net, where
        the first is 0. Function returns None if there is no net by the given
        name or there aren't as many nets as the int passed.
        Like Netica-C GetNthNet_bn.

        Raises TypeError if name is neither an int nor a str.
        """
        if not isinstance(name, (int, str)):
            raise TypeError('An integer or string is required (got type {})'.format(type(name).__name__))

        Netica.GetNthNet_bn.restype = ctypes.c_void_p
        env_ptr = ctypes.c_void_p(self.cptr)

        if isinstance(name, int):
            cptr = Netica.GetNthNet_bn(ctypes.c_int(name), env_ptr)
        else:
            Netica.GetNetName_bn.restype = ctypes.c_char_p
            index = 0
            while True:
                cptr = Netica.GetNthNet_bn(ctypes.c_int(index), env_ptr)
                if cptr is None:
                    break
                # Check the name of the net without creating a net object in python
                net_name = Netica.GetNetName_bn(ctypes.c_void_p(cptr)).decode()
                if net_name == name:
                    break
                index += 1

        err.checkerr()
        if cptr:
            return nt._create_net(cptr)
        return None

    def control_concurrency(self, command, value):
        """Control whether Netica operates single or multi-threaded, and how it
        does its multi-threading.

        Like Netica-C ControlConcurrency_ns.

        Returns Netica's response decoded as str, or None if Netica returned
        a NULL string.
        """
        Netica.ControlConcurrency_ns.restype = ctypes.c_char_p
        # The environment is a pointer, so pass it as c_void_p like every
        # other method of this class does.
        concurrency = Netica.ControlConcurrency_ns(ctypes.c_void_p(self.cptr),
                                                   ctypes.c_char_p(command.encode()),
                                                   ctypes.c_char_p(value.encode()))
        err.checkerr()
        if concurrency is None:
            return None
        return concurrency.decode()

    def get_error(self, severity, after=None):
        """Return a RawNeticaError for the next Netica error (GetError_ns).

        Returns the existing Python RawNeticaError for that error if one is
        registered and still alive in ``rawerror.err_dict``; otherwise
        creates a new RawNeticaError object for it.

        severity : passed through enums.set_error_severity.
        after : RawNeticaError or None, passed to GetError_ns as the error
            to search after.

        Raises TypeError if ``after`` is not None and not a RawNeticaError.
        """
        # NOTE(review): when GetError_ns returns NULL, a RawNeticaError is
        # still built with cptr None (as before) - confirm rawerror handles it.
        if after is not None:
            if not isinstance(after, rawerror.RawNeticaError):
                raise TypeError('A RawNeticaError is required (got type {})'.format(type(after).__name__))
            after = ctypes.c_void_p(after.cptr)
        Netica.GetError_ns.restype = ctypes.c_void_p
        cptr = Netica.GetError_ns(ctypes.c_void_p(self.cptr),
                                  ctypes.c_int(enums.set_error_severity(severity)), after)

        if rawerror.errdict_initialization and cptr in rawerror.err_dict:
            existing = rawerror.err_dict[cptr]()
            # Entries are called to obtain the wrapper; if that yields None
            # the wrapper is gone, so fall through and build a fresh one.
            if existing is not None:
                return existing
        return rawerror.RawNeticaError(('from_environ_get_error', cptr), None, None)

    def clear_errors(self, severity):
        """Clear Netica errors of the given severity (Netica-C ClearErrors_ns).

        When ``rawerror.errdict_initialization`` is set, every registered
        RawNeticaError wrapper that is still alive has its ``cptr`` set to
        None and ``rawerror.err_dict`` is emptied.
        """
        # NOTE(review): all wrappers are invalidated regardless of severity,
        # as in the original - confirm that this matches Netica's semantics.
        Netica.ClearErrors_ns.restype = None
        Netica.ClearErrors_ns(ctypes.c_void_p(self.cptr),
                              ctypes.c_int(enums.set_error_severity(severity)))
        if rawerror.errdict_initialization:
            for ref in rawerror.err_dict.values():
                py_err = ref()
                # Skip wrappers that have already been garbage collected
                if py_err is not None:
                    py_err.cptr = None
            rawerror.err_dict.clear()

    def _test_fault_recovery(self, test_num):
        """Call Netica-C TestFaultRecovery_ns and return its int result.

        Unlike test_fault_recovery, this does not call err.checkerr().
        """
        Netica.TestFaultRecovery_ns.restype = ctypes.c_int
        res = Netica.TestFaultRecovery_ns(ctypes.c_void_p(self.cptr), ctypes.c_int(test_num))
        return res