

def init_stats():
    """Reset the module-level running totals to zero.

    Sets the globals ``num_samples``, ``sum_log_loss`` and ``sum_err_rate``
    to 0, creating them if they do not exist yet.  Module-level globals are
    used so the totals persist across calls, like the totals the C
    ``CheckStats`` reads.
    """
    global num_samples, sum_log_loss, sum_err_rate
    num_samples = 0
    sum_log_loss = 0
    sum_err_rate = 0

def record_stats(log_loss, err_rate):
    """Add one sample's scores to the module-level running totals.

    Args:
        log_loss: value added to the global ``sum_log_loss``.
        err_rate: value added to the global ``sum_err_rate``.

    Also increments the global ``num_samples`` by one.  The three globals
    must already exist (``init_stats`` creates them).
    """
    # Without this declaration the augmented assignments below would treat
    # the names as locals and raise UnboundLocalError.
    global num_samples, sum_log_loss, sum_err_rate
    sum_log_loss += log_loss
    sum_err_rate += err_rate
    num_samples += 1

def check_stats():
    """Print the average log loss and error rate over all recorded samples.

    Reads the module-level ``doprint``, ``num_samples``, ``sum_log_loss`` and
    ``sum_err_rate``.  Prints nothing when ``doprint`` is false.  When no
    samples have been recorded, prints a short notice instead of dividing
    by zero.
    """
    if not doprint:
        return
    if num_samples == 0:
        # Averages are undefined without samples; avoid ZeroDivisionError.
        print("AVERAGES: no samples recorded\n\n")
        return
    avg_log_loss = sum_log_loss / num_samples
    avg_err_rate = sum_err_rate / num_samples
    print(f"AVERAGES: log_loss = {avg_log_loss:.5f}  err_rate = {avg_err_rate:.5f}\n\n")

/*____________________________________________________________________________ CheckStats
Reports the average log loss and error rate (the running sums divided by
  NUM_SAMPLES) when DOPRINT is set, then passes both averages to CEQF together
  with their expected values.
Always returns 0.
*/
int CheckStats (){
	if (doprint){
		printf ("AVERAGES: log_loss = %.5g    err_rate = %.5g\n\n",
				sum_log_loss / num_samples,
				sum_err_rate / num_samples);
		}
	// To also report the raw totals, print sum_log_loss and sum_err_rate here.

	// Average log loss.  It was 0.405434 with version Src 20-11-10; the reason for the change is not known.
	CEQF (777, sum_log_loss / num_samples, 0.38361);

	// Average error rate.  Earlier values seen:
	//   0.17414  (0.174136 != 0.17414) after some minor change, possibly using learning_cases from the mem file instead of the disk file; reason not known.
	//   0.17953  with version Src 20-11-10.
	CEQF (777, sum_err_rate / num_samples, 0.174136);

	return 0;
	}


/*____________________________________________________________________________ SimulateCaseSet
Generates up to NUM_CASES random cases for NODES (after retracting all findings
  of NET before each one), writing each successfully generated case to an
  in-memory stream, and returns a new case-set built from that stream.
FRAC_MISSING, FRAC_WRONG and RAND are passed through to GenerateRandomCase.
If CASE_FILE_NAME is non-NULL, the text of the cases is also written to that
  file.  A failure to open or fully write the file is reported with printf,
  but the case-set is still returned.
The returned case-set is always named "LearningCases", whatever it is used for.
NOTE(review): the memory stream is never deleted here; confirm whether the
  case-set requires the stream to stay alive, and who should eventually free it.
*/
caseset_cs* SimulateCaseSet (const nodelist_bn* nodes, net_bn* net, int num_cases, double frac_missing, double frac_wrong, const char* case_file_name, randgen_ns* rand){
	int rc, case_num, cases_created = 0;
	caseset_cs* cases;
	stream_ns* mem_file = NewMemoryStream_ns ("RandCases", env, NULL);
	for (case_num = 0;  case_num < num_cases;  ++case_num){
		RetractNetFindings_bn (net);
		rc = GenerateRandomCase (nodes, DEFAULT_SAMPLING, frac_missing, frac_wrong, rand);
		if (rc >= 0){											// a negative RC means no case was generated, so there is nothing to write
			WriteNetFindings_bn (nodes, mem_file, case_num, 1.0);
			++cases_created;
			}
		}
	cases = NewCaseset_cs ("LearningCases", env);
	AddFileToCaseset_cs (cases, mem_file, 1.0, NULL);
	if (case_file_name){
		long_ns file_length = 0;
		const char* cases_text = GetStreamContents_ns (mem_file, &file_length);
		errno = 0;
		FILE* file = fopen (case_file_name, "wb");						// not "w", coz CASES_TEXT already has correct line ends (/@CaseFileLineEnds)
		if (!file)  printf ("\n\n  *** Couldn't open file '%s' for writing (%s) ***\n\n", case_file_name, strerror(errno));
		else {
			// Guard against a NULL text pointer or non-positive length before handing them to fwrite
			size_t num_to_write = (cases_text && file_length > 0) ? (size_t) file_length : 0;
			size_t num_written = num_to_write ? fwrite (cases_text, 1, num_to_write, file) : 0;
			int write_failed = (num_written != num_to_write);
			if (fclose (file) != 0)  write_failed = 1;					// fclose flushes, so buffered write errors can surface here
			if (write_failed)  printf ("\n\n  *** Couldn't write all of file '%s' (%s) ***\n\n", case_file_name, strerror(errno));
			else if (doprint)  printf ("Wrote %d cases to file '%s'\n", cases_created, case_file_name);
			}
		}
	return cases;
	}

//= Put version in NeticaEx

/*____________________________________________________________________________ LearnTanNet
Creates a new net called NAME, adds to it the nodes found in CASES, learns a
  TAN structure over those nodes and then learns their CPTs from CASES using
  METHOD_FOR_CPTS.
The TAN target is the node named TARGET_NAME, or if the new net has no such
  node, the first node of the list filled in by AddNodesFromCaseset_bn.
If SAVE_AS_FILENAME is non-NULL, the finished net is also written to that file.
Returns the new net.  The learner and the node list made here are deleted
  before returning; the net is not.
NOTE(review): CHK2 is a macro defined elsewhere; it presumably checks for
  errors, but whether it can leave this function early is not visible here.
*/
net_bn* LearnTanNet (caseset_cs* cases, const char* target_name, learn_method_bn method_for_cpts, const char* name, const char* save_as_filename){
	nodelist_bn* nodes = NULL;
	node_bn* target;
	learner_bn* learner;

	net_bn* net = NewNet_bn (name, env);
	nodes = NewNodeList2_bn (0, net);					// starts empty; passed to AddNodesFromCaseset_bn below to receive the added nodes

	AddNodesFromCaseset_bn (net, cases, NULL, NULL, nodes, "auto_discretize");	// <------ the action
	CHK2(23846)
	
	// Note that nodes in the LearnedTan net may have fewer states than they did in the RAND_NET net, coz some unlikely states may not appear in the case file /@8

	CompareNodeLists (nodes, GetNetNodes2_bn (net, NULL), FALSE, NULL, 23894);	// just to check AddNodesFromCaseset_bn is setting NODES_ADDED correctly
	target = GetNodeNamed_bn (target_name, net);
	if (!target)  target = GetNthNode_bn (nodes, 0);		//@3  (fall back to the first added node)
	LearnTanStructure_bn (nodes, target, cases, NULL, "auto_layout");			// <------ the action
	CHK2(23847)

	learner = NewLearner_bn (method_for_cpts, NULL, env);
	LearnCPTs_bn (learner, nodes, cases, 1.0);									// <------ the action
	CHK2(23848)

	// Optionally save the learned net; the file stream is only needed for the duration of the write
	if (save_as_filename){
		stream_ns* file_net = NewFileStream_ns (save_as_filename, env, NULL);
		WriteNet_bn (net, file_net);
		DeleteStream_ns (file_net);
		}
	DeleteLearner_bn (learner);
	DeleteNodeList_bn (nodes);
	CHK2(23845)
	return net;
	}

/*____________________________________________________________________________ LearnTanRandomCases
Given any case file, with one column called 'Class', this learns the TAN
  net to predict that class.
*/
int LearnTanRandomCases (const char fname[]){
	net_bn* rand_net = NULL, * net;
	const nodelist_bn* nodes = NULL;
	nodelist_bn* test_nodes;
	node_bn* target;
	stream_ns* file_cases;
	caseset_cs* cases, *learning_cases = NULL, *testing_cases = NULL;
	tester_bn* tester;

	const char* target_name = "Node_0";
	int max_num_nodes = 20, max_num_states = 4, max_num_parents = 4;
	int min_num_nodes = 1, min_num_states = 1, min_num_parents = 0;		// cant make nets with no nodes, or there wont be a target node  /@3
	int num_learning_cases = 200, num_testing_cases = 100;
	double frac_missing_learning = 0.2, frac_wrong_learning = 0;
	double frac_missing_testing = 0.6, frac_wrong_testing = 0;

# code actually starts

	rand_net = MakeRandomNet (rg1, min_num_nodes, max_num_nodes, min_num_states, max_num_states, min_num_parents, max_num_parents, FALSE);
	CompileNet_bn (rand_net);
	CHK(23891)
	nodes = GetNetNodes2_bn (rand_net, NULL);
	learning_cases = SimulateCaseSet (nodes, rand_net, num_learning_cases, frac_missing_learning, frac_wrong_learning, fname, rg1);
	CHK(23892)

	file_cases = NewFileStream_ns (fname, env, NULL);
	cases = NewCaseset_cs ("AllCases", env);
	AddFileToCaseset_cs (cases, file_cases, 1.0, NULL);

	strcpy (strstr (fname, ".cas"), ".dne"); # makes new file name ending in dne instead of cas

	net = LearnTanNet (cases, target_name, COUNTING_LEARNING, "LearnedTan", fname);	// <------ the action
	CHK(777)
	if (doprint)  printf ("Created net from case file '%s', having %d nodes\n", fname, LengthNodeList_bn (GetNetNodes2_bn (net, NULL)));
	CompileNet_bn (net);															// for TestWithCaseset_bn
	if (!nodes)  nodes = GetNetNodes2_bn (net, NULL);
	target = GetNodeNamed_bn (target_name, net);
	if (!target)  target = GetNthNode_bn (nodes, 0);		//@3

	test_nodes = NewNodeList2_bn (0, net);
	AddNodeToList_bn (target, test_nodes, 0);
	tester = NewNetTester_bn (test_nodes, NULL, -1);
	
	strcpy (strstr (fname, ".dne"), "_Testing.cas");		// WARNING: FNAME must be big enough

	testing_cases = SimulateCaseSet (nodes, rand_net, num_testing_cases, frac_missing_testing, frac_wrong_testing, fname, rg1);
	CHK(23893)
	TestWithCaseset_bn (tester, testing_cases);										// <------ the action
	CHK(23895)

	double err_rate = GetTestErrorRate_bn (tester, target);
	double log_loss = GetTestLogLoss_bn (tester, target);
	CHK(23896)
	printf ("\nError rate = %.5g      Log loss = %.5g\n\n", err_rate, log_loss);
	PrintConfusionMatrix1 (tester, target);
	CHK(23897)
	// const double* GetTestBinaryScore_bn (tester_bn* tester, const char* score, const node_bn *node, state_bn positive_state, double granularity, int* num_entries, const double** thresholds);
	RecordStats (log_loss, err_rate);
	# use numpy to plot ROC curve? want it to bulge up, if down invert decision - want big area under curve
	# log loss more important, makes sense to add over multiple bayes nets and take average is a reasonable measure
	# move record stats to utils?

	DeleteNetTester_bn (tester);
	DeleteNodeList_bn (test_nodes);
	DeleteNet_bn (net);
	DeleteNet_bn (rand_net);
	if (cases != testing_cases) DeleteCaseset_cs (cases);
	DeleteCaseset_cs (learning_cases);
	DeleteCaseset_cs (testing_cases);
	DeleteStream_ns (file_cases);
	return 0;
	}