// eval4.cpp

#include <ctype.h>
#include <errno.h>
#include <float.h>
#include <math.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "eval4.h"

// Fixed capacities: identifier length (characters), input-line length
// (characters), evaluation-stack depth, and symbol-table size.
static const int	MaxTokenLen		=  10,
					MaxStmtLen		= 100,
					MaxStackDepth	=  10,
					MaxSymbols		=  20;

// Severity carried by the exception thrown from error(); ParseProglet()
// catches it and reports failure.
typedef enum {
	UnknownS_e,
	skip_e,		// Discard statement, but continue program.
	stop_e		// Terminate program.
} severity_t;
typedef char	token_t[MaxTokenLen+1];	// Leave one extra character for the terminating null.
// Token classification set by GetNextToken(); operators and the "\n"
// end-of-statement marker are UnknownTT_e.
typedef enum {
	UnknownTT_e, identifier_e, number_e
} TokenType_t;
// value_t is declared in eval4.h.  NOTE(review): GetNextToken() reads literals
// into it with sscanf "%lf", so it presumably must be double - confirm in eval4.h.
//typedef double	value_t;
// One symbol-table entry: a name and its value.
typedef struct {
	token_t	name;
	value_t	value;
} symbol_t;

static const value_t	PI = 3.14159265358979323846,
						E  = 2.71828182845904523536;

// Text of the statement being parsed and the lexer's position within it.
static char				statement[MaxStmtLen+1];
static int				StmtLen,
						StmtIndex;

// The current token, as produced by GetNextToken().
static token_t			token;
static value_t			TokenValue;	// For storing the value of a literal numeric constant.
static TokenType_t		TokenType;

// Evaluation stack used by push() and pop().
static value_t			stack[MaxStackDepth];
static int				StackDepth;			// Indicates the first empty place on the stack.
// Symbol table.  Entries [0, NumFunctions) are built-in functions whose value
// is the number of arguments they take (ParsePrimary() checks each call against
// it); entries [NumFunctions, NumSymbols) are named constants.  The two counts
// below are maintained by hand and must be updated when entries are added.
static symbol_t			symbols[MaxSymbols] = {
							{"rand",	0.0},
							{"abs",		1.0},
							{"sin",		1.0},
							{"cos",		1.0},
							{"tan",		1.0},
							{"asin",	1.0},
							{"acos",	1.0},
							{"atan",	1.0},
							{"atan2",	2.0},
							{"e",		E},
							{"pi",		PI},
							{"deg",		180.0/PI},
							{"rad",		PI/180.0},
						};
static int				NumFunctions =  9,
						NumSymbols   = 13;

// Set by GetNextToken() once a token other than "\n" has been read; cleared by
// InitParser().  ParseProglet() consults it when deciding whether a value was produced.
static bool				GotToken;

// Support routines. -----------------------------------------------------------

// Compare two NUL-terminated strings for equality.
// Two null pointers compare equal; a null pointer never equals a non-null
// string.  The parameters are const so string literals such as ")" can be
// passed directly (binding a literal to plain char * is ill-formed in C++11).
static bool equal (const char *a, const char *b) {
	if (a == 0 || b == 0) {
		return a == b;	// Equal only when both are null.
	}
	return strcmp (a, b) == 0;
}

static void error (severity_t severity, char *msg1, char *msg2 = 0) {
	if (msg2 != 0) {
		printf ("%s, '%s'\n", msg1, msg2);
	} else {
		printf ("%s\n", msg1);
	}
	throw severity;	// Raise an exception which will be caught in main().
}

// Place v on top of the evaluation stack.  If the fixed-size stack is already
// full, the statement is abandoned through error() instead.
static void push (value_t v) {
	if (StackDepth >= MaxStackDepth) {
		error (skip_e, "stack overflow");	// Throws; nothing below runs.
	}
	stack[StackDepth++] = v;
}

// Remove and return the value on top of the evaluation stack.  Popping an
// empty stack abandons the statement through error().
static value_t pop (void) {
	if (StackDepth <= 0) {
		error (skip_e, "stack underflow");	// Throws; nothing below runs.
	}
	return stack[--StackDepth];
}

// Look up a symbol.
// Return the symbol-table value stored under `name`.
//   AllowFunctions - when false, names of built-in functions (the first
//                    NumFunctions entries, whose value is their argument
//                    count) are rejected with "not a constant".
// An unknown name is rejected with "symbol not found".  Both rejections go
// through error() and therefore abandon the statement.
static value_t GetValue (token_t name, bool AllowFunctions) {
	int	slot = 0;

	while (slot < NumSymbols && !equal (name, symbols[slot].name)) {
		++slot;
	}

	if (slot == NumSymbols) {
		error (skip_e, "symbol not found", name);
	}
	if (!AllowFunctions && slot < NumFunctions) {
		error (skip_e, "not a constant", name);
	}
	return symbols[slot].value;
}

// Lexical analysis. -----------------------------------------------------------

static void GetNextToken (void) {
	int	TokenLen;
	
	// Skip leading whitespace.
	while ((StmtIndex < StmtLen) && (strchr (" \t", statement[StmtIndex]) != 0) ){
		++StmtIndex;
	}
	if (StmtIndex >= StmtLen) {
		error (skip_e, "No more tokens to get");
	}
	token[0] = statement[StmtIndex];
	TokenLen = 1;
	++StmtIndex;
	
	TokenType = UnknownTT_e;
	if (isalpha (token[0])) {
		TokenType = identifier_e;
		// Token is an identifier.
		// Read in the remainder of the symbol name.
		while (	StmtIndex < StmtLen					// Are there any more characters?
				&& TokenLen < MaxTokenLen			// Have we room for more characters?
				&& isalnum (statement[StmtIndex])	// Do we want them?
		) {
			token[TokenLen] = statement[StmtIndex];
			++TokenLen;
			++StmtIndex;
		}
		if (StmtIndex < StmtLen					// Is another character available?
			&& TokenLen == MaxTokenLen				// Have we run out of room for it?
			&& isalnum (statement[StmtIndex])		// Did we want it?
		) {
			token[TokenLen-1] = 0;	// Terminate the partial token so that we can print it,
			error (skip_e, "Identifier too long for token buffer", token);
		}
		GotToken = true;
	} else if (isdigit (token[0])) {
		TokenType = number_e;
		// Token is a number - a 'literal numerical constant'.
		--StmtIndex;	// Back up a character and read the whole number in.
		if (sscanf (&statement[StmtIndex], "%lf%n", &TokenValue, &TokenLen) == 1) {
			StmtIndex = StmtIndex + TokenLen;
		} else {
			error (skip_e, "program error - couldn't read number", &statement[StmtIndex-1]);
		}
		GotToken = true;
	} else {
		// Hopefully an operator of some kind.
		if (token[0] != '\n') GotToken = true;
	}
	token[TokenLen] = 0;	// Terminate token string.
}

static void InitParser (char *line) {
	// Initialise lexical bits.
	StmtLen = strlen (line);
	if (StmtLen > MaxStmtLen) error (skip_e, "Line too long.");
	strcpy (statement, line);
	StmtIndex = 0;
	
	// Initialise parser/evaluation bits.
	StackDepth = 0;
	GotToken = false;
	
	GetNextToken ();
}

// Syntax analysis. ------------------------------------------------------------

static void ParseExpression (void);	// Forward declaration.

static int ParseArgList (void) {
	int	NumArgs = 0;
	while (!equal (")", token)) {
		NumArgs = NumArgs + 1;
		ParseExpression ();			// and leave the result on the stack.
		if (equal (",", token)) {
			GetNextToken ();
		}
	}
	if (!equal (")", token)) {
		error (skip_e, "Unexpected token, comma expected", token);
	}
	return NumArgs;
}

static void ParsePrimary (void) {
	token_t	name;
	
	switch (TokenType) {
		case number_e:
			push (TokenValue);
			GetNextToken ();
			break;
			
		case identifier_e:
			strcpy (name, token);
			
			GetNextToken ();
			if (equal (token, "(")) {	// Is it actually a function name?
				GetNextToken ();
				if (ParseArgList () == (int)(GetValue(name, true) + 0.1)) { // Correct number of arguments?
					// if (errno != 0) ; // Doing another function would clear the error code.
					// else
					if      (equal (name, "rand"))	push (rand());
					else if (equal (name, "abs"))	push (fabs(pop()));
					else if (equal (name, "sin"))	push (sin(pop()));	// Messy, but necessary given
					else if (equal (name, "cos"))	push (cos(pop()));	// the simplistic way we have
					else if (equal (name, "tan"))	push (tan(pop()));	// declared the functions.
					else if (equal (name, "asin"))	push (asin(pop()));
					else if (equal (name, "acos"))	push (acos(pop()));
					else if (equal (name, "atan"))	push (atan(pop()));
					else if (equal (name, "atan2")) {
						value_t	y = pop();
						value_t	x = pop();
						
						push (atan2(x, y));
					} else {
						error (skip_e, "unrecognised function", name);
					}
				} else {
					error (skip_e, "Incorrect number of arguments given for", name);
				}
				GetNextToken ();	// Discard ")".
			} else {
				push (GetValue (name, false));
			}
			break;
			
		default:
			if (equal ("\n", token)) {
				error (skip_e, "unexpected token", "\\n");
			} else {
				error (skip_e, "unexpected token", token);
			}
			break;
	}
}

static void ParseFactor (void) {
	int	unary_minus = false;
	
	if (equal ("+", token)) {
		// Just ignore unary pluses and they'll go away.
		GetNextToken ();
	} else if (equal ("-", token)) {
		// We have to do something about unary minuses though.
		unary_minus = true;
		GetNextToken ();
	}
	
	if (equal ("(", token)) {
		GetNextToken ();
		ParseExpression ();
		GetNextToken ();	// Discard ")"
	} else {
		ParsePrimary ();
	}
	
	if (unary_minus) {
		push (-pop()); // Negate the top value on the stack.
	}
}

static void ParseTerm (void) {
	ParseFactor ();
	while (equal ("*", token) || equal ("/", token)) {
		token_t	op;
		
		strcpy (op, token);
		GetNextToken ();
		ParseFactor ();
		
		value_t	b = pop (),
				a = pop ();

		if (equal ("*", op)) {
			push (a * b);
		} else {
			push (a / b);
		}
	}
}

static void ParseExpression (void) {
	ParseTerm ();
	while (equal ("+", token) || equal ("-", token)) {
		token_t	op;
		
		strcpy (op, token);
		GetNextToken ();
		ParseTerm ();
		
		value_t	b = pop (),
				a = pop ();

		if (equal ("+", op)) {
			push (a + b);
		} else {
			push (a - b);
		}
	}
}


// Signal state written by ArithErrHandler(); ParseProglet() zeroes both before
// each evaluation and inspects SigCode afterwards.
// NOTE(review): a portable handler would only write volatile sig_atomic_t
// objects; these are plain int (possibly declared in eval4.h - check before changing).
int SigCode, SubCode;	// Global variables for passing back error info from handler() to user.
void (*OldArithErrHandler)(int); // Pointer to old handler.  Restore after evaluation.
// SIGFPE handler installed by ParseProglet(); it only records its arguments.
// NOTE(review): it is installed through a cast to void (*)(int).  `sub` is only
// meaningful on runtimes that pass a second argument to SIGFPE handlers (the
// comments in ParseProglet mention BC++Builder); elsewhere its value is
// indeterminate, and calling through the mismatched pointer type is undefined
// behaviour in standard C++.
void ArithErrHandler (int sig, int sub) {
	// Could choose to ignore FPE_UNDERFLOW here.
	SigCode = sig;
	SubCode = sub;
}

// Evaluate one statement.
//   line  - NUL-terminated text of at most MaxStmtLen characters.  The parser
//           treats a "\n" token as end-of-statement, so the text is expected to
//           end with a newline; without one, parsing fails with "No more tokens to get".
//   value - receives the result when an expression was evaluated successfully.
// Returns true on success.  On failure a diagnostic has already been printed
// (by error(), perror() or the signal report below) and false is returned.
// A SIGFPE handler is installed for the duration of the call, and the previous
// disposition is restored afterwards.
extern bool ParseProglet (char *line, value_t *value) {
	bool	success = true;

	SigCode = 0;	// To catch errors that raise signals.
	SubCode = 0;
	errno = 0;		// To catch errors the library reports via errno instead of a signal.
	OldArithErrHandler = signal (SIGFPE, (void (*) (int))ArithErrHandler);

	try {
		InitParser (line);

		ParseExpression ();
		if (!equal ("\n", token)) error (skip_e, "Unexpected character at end-of-line", token);

		if (GotToken) {
			// If we actually got something there should
			// now be exactly one item on the stack.
			if (StackDepth > 1) error (skip_e, "Final stack too deep");
			if (StackDepth < 1) error (skip_e, "No final value on stack");
			*value = pop ();
		} else {
			// Intended for a blank line.
			// NOTE(review): a line containing only "\n" never gets here:
			// ParseExpression() reaches ParsePrimary(), which reports
			// "unexpected token, '\n'".  Changing that would alter what the
			// caller sees for blank lines, so it is left as is.
		}
	}
	catch (severity_t) {
		// error() has already printed the diagnostic.
		success = false;
	}

	// One of these ought to catch most arithmetic errors.
	if (errno != 0) {
		success = false;
		perror ("eval");
	}
	if (SigCode != 0) {
		success = false;
		// Only SIGFPE is trapped here.  Runtimes that pass a sub-code to the
		// handler (BC++Builder: FPE_ZERODIVIDE, FPE_OVERFLOW, ...) leave it in
		// SubCode, which could be used to give a more specific message.
		switch (SigCode) {
		case SIGFPE:	printf ("Floating point error.\n"); break;
		default:		printf ("Unrecognised signal code %d.\n", SigCode); break;
		}
	}
	// If installing our handler failed, signal() returned SIG_ERR and the
	// original disposition is still in place; do not "restore" SIG_ERR.
	if (OldArithErrHandler != SIG_ERR) {
		signal (SIGFPE, OldArithErrHandler);
	}

	return success;
}
