POJ Solution 1100

http://poj.org/problem?id=1100

#include <iostream>
#include <sstream>
#include <string>
#include <vector>
using namespace std;
// All-ones unsigned value used as a "not found" sentinel.
// NOTE(review): this is only as wide as `unsigned`; it equals
// std::string::npos only where size_t has the same width.
#define NIL        (unsigned)-1
// Size of the fixed scratch arrays declared below.
#define MAX        12
typedef vector <int> VI;
typedef unsigned int UI;

// cas: running case number shown in the "Equation #" header (see main).
// N:   value returned by expression() for the current line.
// F:   set to 1 by dfs() once a matching operator assignment was printed.
int cas, N, F;
// equ:     the number parsed from the left of '=' (the value to reach).
// tmp_num: scratch copy of the operands, overwritten by check().
// UFS:     not referenced anywhere in the visible code.
int equ, tmp_num[MAX], UFS[MAX];
// num:  operands in textual order.
// prio: parenthesis nesting depth recorded at every blank between operands
//       (expression() overwrites entries with -1 while ordering them).
// pos:  index inside `expr` of each separating blank (last entry is -1).
// opr:  operator index (0..2) currently chosen for each blank.
// v0:   blank indices listed in the order they are evaluated.
// v1:   for every evaluation step, the two tmp_num slots that are combined.
VI num, prio, pos, opr, v0, v1;
// Input line; expression() replaces it with the normalised right-hand side.
string expr;
// Stream shared by every call to cast().
stringstream ss;
// Character printed for operator index 0, 1 and 2.
char opr_sign[3] =
{ '+', '-', '*' };
// Function applied for operator index 0, 1 and 2; filled in by main().
int (*opr_f[3])(int a, int b);

// Addition operator; main() stores it in opr_f[0], matching opr_sign[0] ('+').
int inline add(int a, int b)
{
    const int sum = a + b;
    return sum;
}
// Subtraction operator (left minus right); main() stores it in opr_f[1],
// matching opr_sign[1] ('-').
int inline sub(int a, int b)
{
    const int difference = a - b;
    return difference;
}
// Multiplication operator; main() stores it in opr_f[2], matching
// opr_sign[2] ('*').
int inline mul(int a, int b)
{
    const int product = a * b;
    return product;
}
// Union-find (disjoint-set) structure over the fixed index range 0..11.
// An index whose parent is itself is the representative of its set.
class Disjoint
{
public:
    int parent[12];

    // Every index starts out as the sole member of its own set.
    Disjoint()
    {
        int idx = 0;
        while (idx < 12)
        {
            parent[idx] = idx;
            ++idx;
        }
    }

    // Returns the representative of the set containing x. Every index
    // visited on the way is re-pointed directly at that representative.
    int Find(int x)
    {
        // First pass: walk up to the representative.
        int root = x;
        while (parent[root] != root)
            root = parent[root];

        // Second pass: shorten the path that was just walked.
        while (x != root)
        {
            const int above = parent[x];
            parent[x] = root;
            x = above;
        }
        return root;
    }

    // Merges the set containing b into the set containing a; the
    // representative of a's set becomes the representative of the union.
    void inline Union(int a, int b)
    {
        const int keep = Find(a);
        const int absorb = Find(b);
        parent[absorb] = keep;
    }
};
// Converts the leading integer in `s` to an int.
//
// Leading whitespace is skipped and parsing stops at the first character
// that cannot belong to the number, so "12 " and "12)" both yield 12.
// Returns 0 when `s` holds no number at all (empty or non-numeric text).
//
// A fresh stream is used for every call: reusing one shared stream and only
// clearing its flags would leave unread characters of an earlier token in
// the buffer, and they would be read ahead of the next token.
int inline cast(std::string s)
{
    int value = 0;              // defined result even if extraction fails
    std::istringstream in(s);
    in >> value;
    return value;
}
// Normalises the spacing of an equation line so that it can be split on
// single blanks afterwards.
//
// After the call:
//   * '=' has a blank on both sides (no blank is added in front of a '='
//     that is the very first character);
//   * adjacent groups ")(" are separated as ") (";
//   * there is no blank directly after '(' or directly before ')';
//   * to the right of '=', every '(' is preceded by a blank or another '('
//     and every ')' is followed by a blank or another ')';
//   * runs of blanks are collapsed to one and leading/trailing blanks are
//     removed.
//
// A line without '=' only gets the parenthesis/blank clean-up; nothing is
// read outside the string in that case.
std::string fix(std::string s)
{
    typedef std::string::size_type Index;
    const Index npos = std::string::npos;

    // Put blanks around '='. The blank after it is inserted first so that
    // the index of '=' is still valid for the second insertion.
    const Index eq = s.find('=');
    if (eq != npos)
    {
        if (eq + 1 >= s.length() || s[eq + 1] != ' ')
            s.insert(eq + 1, 1, ' ');
        if (eq > 0 && s[eq - 1] != ' ')
            s.insert(eq, 1, ' ');
    }

    Index at;
    while ((at = s.find(")(")) != npos)     // ")("  ->  ") ("
        s.insert(at + 1, 1, ' ');
    while ((at = s.find("( ")) != npos)     // "( x" ->  "(x"
        s.erase(at + 1, 1);
    while ((at = s.find(" )")) != npos)     // "x )" ->  "x)"
        s.erase(at, 1);

    // Temporary trailing blank: guarantees that s[i + 1] exists for every
    // ')' inspected below.
    s.push_back(' ');
    for (Index i = s.find('='); i < s.length(); ++i)
    {
        if (s[i] == '(')
        {
            // After inserting, position i holds the new blank and the '('
            // is looked at again (now correctly preceded) on the next turn.
            if (i > 0 && s[i - 1] != ' ' && s[i - 1] != '(')
                s.insert(i, 1, ' ');
        }
        else if (s[i] == ')')
        {
            if (s[i + 1] != ' ' && s[i + 1] != ')')
            {
                s.insert(i + 1, 1, ' ');
                ++i;                        // skip the blank just added
            }
        }
    }
    s.erase(s.length() - 1);                // drop the temporary blank

    // Collapse runs of blanks; stay on the same index after an erase.
    for (Index i = 1; i < s.length(); )
    {
        if (s[i - 1] == ' ' && s[i] == ' ')
            s.erase(i, 1);
        else
            ++i;
    }

    // At most one blank can be left at either end now.
    if (!s.empty() && s[0] == ' ')
        s.erase(0, 1);
    if (!s.empty() && s[s.length() - 1] == ' ')
        s.erase(s.length() - 1);
    return s;
}
int expression(string s)
{
    string tmp;
    int blank, lastblank, p, pp;
    int lp, rp;
    Disjoint ufs;
    num.clear();
    prio.clear();
    pos.clear();
    opr.clear();
    v0.clear();
    v1.clear();
    lp = rp = 0;

    s = fix(s);
    blank = s.find(' ', 0);
    tmp = s.substr(0, blank);
    equ = cast(tmp);
    blank = s.find(' ', blank+1);
    expr = s = s.substr(blank+1, s.length()-blank-1);

    blank = -1;
    do
    {
        lastblank = s.find(' ', blank+1);
        tmp = s.substr(blank+1, (lastblank == -1 ? s.length() : lastblank)
                -blank);
        blank = lastblank;
        pos.push_back(blank);
        while (tmp.find('(', 0) != NIL)
            tmp.erase(0, 1);
        while (tmp.find(')', 0) != NIL)
            tmp.erase(tmp.length()-1, 1);
        num.push_back(cast(tmp));
    } while (blank != -1);

    for (UI i = 0, p = 0; i < s.length(); i++)
    {
        if (s[ i ] == '(')
            p++, lp++;
        else if (s[ i ] == ')')
            p--, rp++;
        else if (s[ i ] == ' ')
            prio.push_back(p);
    }

    while (p != -1)
    {
        p = -1, pp = 0;
        for (UI i = 0; i < prio.size(); i++)
            if (prio[ i ]> p)
                p = prio[ i ], pp = i;
        if (p == -1)
            break;
        for (UI i = pp; i < prio.size(); i++)
            if (prio[ i ] == p)
                v0.push_back(i), prio[ i ] = -1;
            else
                break;
    }
    for (UI i = 0; i < v0.size(); i++)
    {
        v1.push_back(ufs.Find(v0[ i ]));
        v1.push_back(ufs.Find(v0[ i ]+1));
        ufs.Union(v0[ i ], v0[ i ]+1);
    }
    for (UI i = 0; i < prio.size(); i++)
        opr.push_back(0);

    if (lp != rp)
        return 0;

    return num.size();
}

// Evaluates the current expression with the operators currently selected in
// `opr` and returns the result.
//
// Works on a scratch copy of the operands (tmp_num). The steps are replayed
// in the order stored in v0; v1 holds, per step, the pair of scratch slots
// to combine, and the outcome is written back into the first slot of the
// pair. The value left in slot 0 is returned.
int check()
{
    for (int slot = 0; slot < N; ++slot)
        tmp_num[slot] = num[slot];

    int step = 0;
    while (step < N - 1)
    {
        const int lhs = v1[2 * step];
        const int rhs = v1[2 * step + 1];
        const int which = opr[v0[step]];    // operator chosen for this blank
        tmp_num[lhs] = opr_f[which](tmp_num[lhs], tmp_num[rhs]);
        ++step;
    }
    return tmp_num[0];
}
// Writes "<equ>=<s>" followed by a newline to standard output, after
// replacing the blank between each pair of neighbouring operands in `s`
// (positions taken from `pos`) with the sign of the operator currently
// chosen for it in `opr`.
void print(string s)
{
    int gap = 0;
    while (gap < N - 1)
    {
        const int where = pos[gap];
        s[where] = opr_sign[opr[gap]];
        ++gap;
    }
    cout << equ << '=' << s << endl;
}
// Depth-first search over operator choices: tries operator index 0, 1, 2
// for blank k and recurses on blank k+1.
//
// Once all N-1 blanks have an operator, the expression is evaluated with
// check(); on a match with `equ` the solution is printed and F is set,
// which makes every remaining call return immediately.
void dfs(int k)
{
    if (F)
        return;

    if (k != N - 1)
    {
        for (int choice = 0; choice < 3; ++choice)
        {
            opr[k] = choice;
            dfs(k + 1);
        }
        return;
    }

    // Every blank has an operator assigned: test this combination.
    if (check() == equ)
    {
        print(expr);
        F = 1;
    }
}

// Reads equation lines from standard input until a line consisting of "0"
// (or end of input). For each line it prints
//
//     Equation #<k>:
//     <solution, or "Impossible">
//     <blank line>
//
// where the solution is the equation with '+', '-' or '*' filled in between
// the operands.
int main()
{
    // Operator table; the order matches opr_sign ('+', '-', '*').
    opr_f[0] = add;
    opr_f[1] = sub;
    opr_f[2] = mul;
    cas = 1;
    while (std::getline(std::cin, expr))
    {
        // Tolerate input with Windows line endings.
        if (!expr.empty() && expr[expr.size() - 1] == '\r')
            expr.erase(expr.size() - 1);
        if (expr == "0")
            break;

        N = expression(expr);
        // The header must end in a real newline ("\n", not a literal 'n').
        std::cout << "Equation #" << cas++ << ":\n";
        if (N <= 0)
        {
            // expression() rejected the line; searching with N == 0 would
            // never reach the end condition of dfs().
            std::cout << "Impossible\n";
        }
        else if (N == 1)
        {
            // A single operand has no operator to choose.
            if (equ == num[0])
                print(expr);
            else
                std::cout << "Impossible\n";
        }
        else
        {
            F = 0;
            dfs(0);
            if (!F)
                std::cout << "Impossible\n";
        }
        std::cout << '\n';
    }
    return 0;
}
											
This entry was posted in poj. Bookmark the permalink.