Problem 2255 6/2(1+2) ICPC模擬国内予選2011 D

例について考えてみる

6 / 2 * ( 1 + 2 )
> 6 / ( 2 * ( 1 + 2 ) )
> ( 6 / 2 ) * ( 1 + 2 )
6 / 2 * ( 1 + 2 * 3 )
> 6 / ( 2 * ( ( 1 + 2 ) * 3 ) )
> 6 / ( 2 * ( 1 + ( 2 * 3 ) ) )
> ( 6 / 2 ) * ( ( 1 + 2 ) * 3 )
> ( 6 / 2 ) * ( 1 + ( 2 * 3 ) )

解き方

私が必死にパースして構文木作って、組み合わせ生成してってやったらTLEした。
次を参考にして解いてみた。2011年度ICPC模擬国内予選


解き方の方針としては2つに分割して再帰して計算結果の組み合わせを求める。
計算結果は複数与えられるので、setを用いる。(重複しないところも)

set<int> solve(int a,int b) // こんな感じの関数
  set<int> lhs = solve(a,i-1); // 入力文字[a..a+x-1]
  set<int> rhs = solve(i+1,b); // 入力文字[a+x+1..b]

  FOREACH(set<int>,lhs,ll)
  FOREACH(set<int>,rhs,rr)
    // 演算

ソースコード

#include <iostream>
#include <string>
#include <set>
#include <cstdlib>
#include <algorithm>
using namespace std;

#define FOREACH(t,p,it) for(t::iterator it=p.begin();it!=p.end();++it)
#define foreach(t,p) FOREACH(t,p)
#define all(p) p.begin(),p.end()
#define REP(i,p) for(int i=0;i<p;i++)
#define rep(p) REP(i,p)
string s;

set<int> solve(int a,int b)
{
    bool found = false; // 62とか(1+2)の場合はフラグが落ちたまま
    int nest=0;
    set<int> ans;

    for(int i=a;i<b;i++)
    {
        if( s[i] == '(' ) nest++; // ()の中の演算子はスルー
        if( s[i] == ')' ) nest--; // ()の外の演算子は注目

        if(!nest&&(s[i]=='+'||s[i]=='-'||s[i]=='*'||s[i]=='/'))
        {
            found = true; // 数字だけ()だけのパターンではない

            set<int> lhs = solve(a,i-1); // 左側の式の計算結果の集合
            set<int> rhs = solve(i+1,b); // 右側の式の計算結果の集合

            FOREACH(set<int>,lhs,ll) 
            FOREACH(set<int>,rhs,rr)
            {
                if( s[i]=='+' ) ans.insert( (*ll) + (*rr) );
                if( s[i]=='-' ) ans.insert( (*ll) - (*rr) );
                if( s[i]=='*' ) ans.insert( (*ll) * (*rr) );
                if( s[i]=='/' && (*rr)!=0 ) ans.insert( (*ll) / (*rr) );
            }
        }
    }

    if(!found)
    {
        if(s[a]=='(') return solve(a+1,b-1); // (1+2)とか(1+2/3)とか
        ans.insert( atoi( s.substr(a,b-a+1).c_str() ) ); // 1243とか5とか
    }

    return ans;
}


int main()
{
    while(cin>>s&&s!="#") cout << solve(0,s.size()-1).size() << endl;
    return 0;
}