最近在拿LLVM写玩具语言, 实现词法分析器和语法分析器后想做做测试, 这道题刚好可以帮忙测二元表达式的求值。
    文法规则题目中给了, '^'右递归, '+ - * / %'左递归, 一元的'-'可以在前面补个0方便处理, 括号递归定义即可。 简单写成这个样子:

bin_expr := primary (binop primary)*
binop := '+' | '-' | '*' | '/' | '^' | '%'
primary := (digit)+
        := '(' bin_expr ')'
        := '-' primary

    接下来按照算术优先级递归+迭代,递归主体包含三个关键的优先级变量, 分别是整个lhs的前置运算符优先级old_prec, 当前二元运算符优先级prec与下一个二元运算符优先级next_prec. 通常情况下若prec<=old_prec则将lhs返回, 若prec<next_prec则以prec和当前的rhs为参数开启新一轮递归, 每次迭代结束将lhs和rhs合并为新的lhs进行下一轮迭代,考虑'^'时只需要加几个特判就好。具体可以参考代码#186.
    错误处理可以直接抛异常,然后在main里catch一下输出就好。另外这题的越界检查不是那么严格,求幂次时可以用pow糊,还有循环求幂的话会T…

//bin_expr := primary (binop primary)*
//binop := '+' | '-' | '*' | '/' | '^' | '%'
//primary := (digit)+
// := '(' bin_expr ')'
// := '-' primary

#include <bits/stdc++.h>

using namespace std;

// Debug switch: when true, the parser, the evaluator and main() print a trace
// of their work to stdout.
bool DEBUG = false;

// Current nesting level of the debug trace; it determines how long the prefix
// returned by blank() is.
int depth = 0;

// Returns the indentation prefix for the debug trace: two '-' characters per
// nesting level. A zero or negative depth yields an empty string.
std::string blank() {
    if (depth <= 0) {
        return std::string();
    }
    const std::string::size_type dash_count =
        static_cast<std::string::size_type>(depth) * 2;
    return std::string(dash_count, '-');
}

// Token kinds reported by the lexer. The values start at 128, above the ASCII
// range; the lexer reports '(' and ')' as their own character codes instead
// of using an enumerator.
enum Token {
    EOF_TOKEN = 128,  // end of input reached
    OP_TOKEN = 129,   // an operator
    NUM_TOKEN = 130   // an integer literal
};

// Tokenizer for arithmetic expressions built from decimal integers, the
// operators + - * / % ^ and parentheses.
//
// The constructor removes all whitespace from the input, checks that the
// parentheses are balanced and fills the operator-precedence table.
// nextToken() advances to the next token; getToken(), getNum(), getOp() and
// getOpPrec() inspect the current one.
//
// All errors are reported by throwing a `const char*` message.
class Lexer {
private:
    std::string m_code;                     // input with whitespace removed
    std::string::size_type m_idx;           // index of the next unread character
    int m_token;                            // most recent token (0 before the first nextToken())
    long long m_num;                        // value of the most recent NUM_TOKEN
    std::string m_op;                       // text of the most recent OP_TOKEN
    std::map<std::string, int> m_op_prec;   // operator -> precedence (higher binds tighter)

    // <cctype> classifiers have undefined behaviour for negative arguments, so
    // plain `char` must be converted to `unsigned char` first.
    static bool isDigitChar(char c) {
        return std::isdigit(static_cast<unsigned char>(c)) != 0;
    }
    static bool isSpaceChar(char c) {
        return std::isspace(static_cast<unsigned char>(c)) != 0;
    }

    // Scans one token starting at m_idx and returns its kind: EOF_TOKEN,
    // NUM_TOKEN (value stored in m_num), OP_TOKEN (text stored in m_op), or the
    // character code of '(' / ')'.
    // Throws if a literal exceeds INT_MAX or the character is not recognised.
    int m_nextToken() {
        if (m_idx >= m_code.size()) return EOF_TOKEN;

        const char c = m_code[m_idx];
        if (isDigitChar(c)) {
            m_num = 0;
            do {
                m_num = m_num * 10 + (m_code[m_idx] - '0');
                // Checked after every digit, so m_num can never overflow long long.
                if (m_num > INT_MAX) {
                    throw "lexer.m_nextToken: overflow";
                }
                m_idx++;
            } while (m_idx < m_code.size() && isDigitChar(m_code[m_idx]));
            return NUM_TOKEN;
        }

        if (m_op_prec.count(std::string(1, c)) != 0) {
            m_op = c;
            m_idx++;
            return OP_TOKEN;
        }

        if (c != '(' && c != ')') {
            throw "lexer.m_nextToken: unknown character";
        }
        m_idx++;
        return c;
    }

public:
    // Prepares `code` for tokenizing.
    // Throws "paran matching error" if the parentheses are not balanced.
    Lexer(const std::string &code) : m_idx(0), m_token(0), m_num(0), m_op("") {
        for (char ch : code) {
            if (!isSpaceChar(ch)) m_code += ch;
        }

        int open_count = 0;
        for (char ch : m_code) {
            if (ch == '(') {
                open_count++;
            } else if (ch == ')') {
                // A ')' with no matching '(' before it.
                if (open_count <= 0) throw "paran matching error";
                open_count--;
            }
        }
        if (open_count != 0) throw "paran matching error";

        m_op_prec["^"] = 3;
        m_op_prec["%"] = 2;
        m_op_prec["/"] = 2;
        m_op_prec["*"] = 2;
        m_op_prec["-"] = 1;
        m_op_prec["+"] = 1;
    }
    ~Lexer() {}

    // Kind of the current token.
    int getToken() { return m_token; }
    // Value of the current token; it must be a NUM_TOKEN.
    long long getNum() { assert(m_token == NUM_TOKEN); return m_num; }
    // Text of the current token; it must be an OP_TOKEN. The reference points
    // at internal storage that a later nextToken() may overwrite.
    std::string& getOp() { assert(m_token == OP_TOKEN); return m_op; }
    // Precedence of the current token; it must be an OP_TOKEN.
    int getOpPrec() { assert(m_token == OP_TOKEN); return m_op_prec[m_op]; }
    // Advances to the next token and returns its kind.
    int nextToken() { return (m_token = m_nextToken()); }
};

// Root of the AST class hierarchy. Both operations are virtual so that derived
// node types can override them; the base versions return neutral values.
class TopAST {
public:
    TopAST() = default;
    virtual ~TopAST() = default;

    // Evaluates the node. The base implementation returns 0.
    virtual long long calculate() { return 0LL; }

    // Renders the node as text. The base implementation returns an empty string.
    virtual std::string toString() { return std::string(); }
};

// Binary-operator node: applies the operator m_op to the values of its two
// children. It is also the base class of Int32AST, which is why a
// default-constructed node with an empty operator and no children exists.
//
// Ownership: a node owns the children handed to its constructor and destroys
// them together with itself. Holding them in unique_ptr makes the node
// non-copyable, which rules out an accidental double delete.
//
// Errors are reported by throwing a `const char*` message.
class BinAST : public TopAST {
private:
    std::string m_op;
    std::unique_ptr<BinAST> m_lhs, m_rhs;
public:
    BinAST() : TopAST() {}
    // Takes ownership of `lhs` and `rhs`.
    BinAST(const std::string& op, BinAST *lhs, BinAST *rhs)
        : TopAST(), m_op(op), m_lhs(lhs), m_rhs(rhs) {}
    virtual ~BinAST() {}

    // Evaluates both children (left first) and applies the operator.
    // Throws on division or modulo by zero, a negative exponent, 0^0, an
    // unknown operator, or a result outside [INT_MIN, INT_MAX].
    virtual long long calculate() {
        assert(m_lhs != nullptr && m_rhs != nullptr);
        if (DEBUG) {
            std::cout << blank() + "@" + m_op << std::endl;
            depth++;
        }
        const long long lv = m_lhs->calculate();
        const long long rv = m_rhs->calculate();
        if (DEBUG) {
            depth--;
        }
        long long ans = 0;
        if (m_op == "+") {
            ans = lv + rv;
        } else if (m_op == "-") {
            ans = lv - rv;
        } else if (m_op == "*") {
            ans = lv * rv;
        } else if (m_op == "/") {
            if (rv == 0) {
                throw "binast.calculate: /0";
            }
            ans = lv / rv;
        } else if (m_op == "%") {
            if (rv == 0) {
                throw "binast.calculate: %0";
            }
            ans = lv % rv;
        } else if (m_op == "^") {
            if (rv < 0) {
                throw "binast.calculate: x^(-x)";
            }
            if (lv == 0 && rv == 0) {
                throw "binast.calculate: 0^0";
            }
            const double temp = std::pow(static_cast<double>(lv), static_cast<double>(rv));
            // Range-check the double itself BEFORE converting it: casting a
            // double that does not fit into long long is undefined behaviour.
            if (temp < INT_MIN || temp > INT_MAX) throw "binast.calculate: overflow";
            ans = static_cast<long long>(temp);
        } else {
            throw "binast.calculate: unknown binop";
        }
        if (ans < INT_MIN || ans > INT_MAX) throw "binast.calculate: overflow";
        return ans;
    }

    // Renders the node as "<lhs><op><rhs>" without adding parentheses.
    virtual std::string toString() {
        assert(m_lhs != nullptr && m_rhs != nullptr);
        return m_lhs->toString() + m_op + m_rhs->toString();
    }
};

class Int32AST : public BinAST {
private:
long long m_val;
public:
Int32AST(long long val) : BinAST(), m_val(val) {}
virtual ~Int32AST() {}
virtual long long calculate() {
return m_val;
}
virtual std::string toString() {
return std::to_string(m_val);
}
};

class Parser {
private:
Lexer *m_lex;
public:
Parser(const std::string &code) { m_lex = new Lexer(code); }
~Parser() { delete m_lex; }
TopAST *driver() {
m_lex->nextToken();
TopAST *ast = binParser(0, primaryParser());
if (m_lex->getToken() != EOF_TOKEN) {
throw "parser.driver: missing eof";
}
return ast;
}
BinAST *binParser(int old_prec, BinAST* lhs) {
while (true) {
if (m_lex->getToken() != OP_TOKEN) {
if (m_lex->getToken() != ')' && m_lex->getToken() != EOF_TOKEN) {
throw "parser.binParser: unknown binop1";
}
return lhs;
}
std::string op = m_lex->getOp();
int prec = m_lex->getOpPrec();
if ((op != "^" && prec <= old_prec)) return lhs;
if (DEBUG) {
cout << blank() + "\'" + op + "\'" << endl;
}
m_lex->nextToken();
BinAST *rhs = primaryParser();
if (m_lex->getToken() != OP_TOKEN) {
if (m_lex->getToken() != ')' && m_lex->getToken() != EOF_TOKEN) {
throw "parser.binParser: unknown binop2";
}
BinAST *ast = new BinAST(op, lhs, rhs);
return ast;
} else {
if ((m_lex->getOp() != "^" && prec < m_lex->getOpPrec()) ||
(m_lex->getOp() == "^" && prec <= m_lex->getOpPrec())) {
if (DEBUG) {
depth++;
}
rhs = binParser(prec, rhs);
if (DEBUG) {
depth--;
}
}
lhs = new BinAST(op, lhs, rhs);
}
}
}
BinAST *primaryParser() {
BinAST *ast = nullptr;
if (m_lex->getToken() == OP_TOKEN) {
std::string op = m_lex->getOp();
if (op != "-") {
throw "parser.unParser: unknown unop";
}
m_lex->nextToken();
if (DEBUG) {
cout << blank() + "un-" << endl;
}
ast = new BinAST(op, new Int32AST(0), primaryParser());
} else if (m_lex->getToken() == '(') {
m_lex->nextToken();
if (DEBUG) {
cout << blank() + "(" << endl; depth++;
}
ast = binParser(0, primaryParser());
if (m_lex->getToken() != ')') {
throw "parser.primaryParser: missing )";
}
if (DEBUG) {
depth--; cout << blank() + ")" << endl;
}
m_lex->nextToken();
} else if (m_lex->getToken() == NUM_TOKEN) {
ast = int32Parser();
if (DEBUG) {
cout << blank() + ast->toString() << endl;
}
} else {
throw "parser.primaryParser: unknown character";
}
return ast;
}
Int32AST *int32Parser() {
Int32AST *ast = new Int32AST(m_lex->getNum());
m_lex->nextToken();
return ast;
}
};

// Prints 500 randomly generated expressions to stdout, one per line, and then
// terminates the process with exit(0). Each expression is a number followed by
// 20 "operator operand" steps with randomly placed parentheses and unary
// minus signs. Stray ')' or '-' characters are injected with a small
// probability, so some of the generated expressions are malformed.
void createTestCase() {
    srand((unsigned int)time(0));

    // One uniform sample in [0, 1]; every call consumes exactly one rand() value.
    auto roll = []() -> double { return 1.0 * rand() / RAND_MAX; };

    for (int round = 0; round < 500; ++round) {
        int open_parens = 0;  // '(' emitted so far that still need a ')'
        std::string expr = std::to_string(rand() % 20);

        for (int step = 0; step < 20; ++step) {
            // Occasionally inject a stray character in front of the operator.
            const double noise = roll();
            if (noise <= 0.025) {
                expr += ")";
            } else if (noise <= 0.05) {
                expr += "-";
            }

            // Pick the binary operator.
            const double pick = roll();
            const char *binop = "+";
            if (pick <= 0.1) {
                binop = "^";
            } else if (pick <= 0.2) {
                binop = "%";
            } else if (pick <= 0.4) {
                binop = "*";
            } else if (pick <= 0.6) {
                binop = "/";
            } else if (pick <= 0.8) {
                binop = "-";
            }
            expr += binop;

            // Optional '(' and unary '-' in front of the operand.
            if (roll() <= 0.5) {
                expr += "(";
                ++open_parens;
            }
            if (roll() <= 0.5) {
                expr += "-";
            }
            expr += std::to_string(rand() % 20);

            // Optionally close one pending parenthesis. The sample is drawn
            // even when nothing is open.
            if (roll() <= 0.5 && open_parens != 0) {
                --open_parens;
                expr += ")";
            }
        }

        // Close whatever is still open.
        for (; open_parens > 0; --open_parens) {
            expr += ")";
        }
        std::cout << expr << std::endl;
    }
    exit(0);
}

int main() {
// createTestCase();
int t;
cin >> t; getchar();
string str;
for (int i = 1; i <= t; i++) {
try {
getline(cin, str);
cout << "Case " + std::to_string(i) + ": ";
Parser parser(str);
unique_ptr<TopAST> ast(parser.driver());
if (DEBUG) {
cout << "========================================" << endl;
}
cout << ast->calculate() << endl;
} catch (const char* s) {
cout << "ERROR!" << endl;
if (DEBUG) {
cout << s << endl;
}
}
}
}