(11/4新增) AST 子树结构 DP

一个朋友看面经看到的,说是 VMWare 某次 on campus interview 的问题,我觉得挺有意思,就补充到这个分类里了~

给一个字符串,字符串是一个表达式.

比如: T & F ^ F | T & T ^ T。

T , F 分别表示 true OR false, 操作符有 and, or, xor。 给你任意一个表达式,你可以任意加括号组合,问一个有多少组合的方法使这个表达式能得到 True。

比如 T ^ F & T 可以写成 (T ^ F) & T 或者 T ^ (F & T) 所以一共是两种。

这个问题比较 simple 而 naive 的切入方式,就是先从建立 abstract syntax tree (AST) 开始,像我们做 reverse polish notation 求值一样,根据 operator 和 operand 来建立树形结构。

在这里我们不同的加括号方式,导致了不同的运算顺序,最终导致了不同的 AST 子树结构。如果我们对这个 AST 去做 in-order traversal,会发现所有解的 in-order 结果都和不加括号的原始输入一样。

画出来的话,第一个就是

      &
     / \
    ^   T
   / \
  T   F

第二个是

      ^
     / \
    T   &
       / \
      T   F

那么这个问题在结构上最简单粗暴的定义,就是 -- 枚举出所有可能的 AST 树,并且计算出到底有多少个树可以 evalute true. 暴力解就有了。

下一步是,寻找可能的 overlapping subproblem,和 recurrence relation

常见的做法就是 "加/减一个元素",比如说刚才的输入是 T ^ F & T ,加一个 F & ,变成 F & T ^ F & T.

这个时候立刻可以看出,一种划分方式是 F & (T ^ F & T),括号里的部分就是刚才计算的问题。因此可以发现,这个问题虽然每一个 AST 是 disjoint 的树形结构,但是树与树之间,是存在 overlapping subproblem 的;体现在 AST 上,就是 identical subtree; 体现在输入 String 上,就是 identical subarray.

因此,这一定是一个可以靠动态规划优化的问题。

考虑到我们的操作符有三种,"AND","OR" 和 "XOR",这题和其他的问题还稍微有点不一样:AND 操作符在两边找 True,OR 操作符在两边找到 True 就行,XOR 要找一边 True 一边 False. 所以这题的子问题划分好了之后,我们需要同时记录这个 subproblem 里,到底有多少种 evaluate true/false 的解。(对于固定子树 size N,这其实是一个常数)

对于给定的 String input,这题的 top-down 解决步骤是:

  • 依次枚举每一个操作符的位置作为 root,同时以该操作符为分界线,把输入分成左右两个子问题;

  • 根据左右子问题分别的 true/false 数量,还有当前操作符的类型,计算出当前 subtree 结构的 true/false 数量;

  • 一个 subproblem 的 true/false 数量,是其所有以操作符为 root 的划分结果的和;

  • 所有子问题都可以 top-down with memoization,用记忆化搜索,字符串做 key,true/false 解的数量做 value.

  • bottom-up 的迭代方式也可以,考虑到子串的连续性,可以直接用 int[i][j] 的方式来定义,从第 i 到第 j 个 index 上的子问题里,有多少个 true/false 就行了。

一个记忆化搜索的例子,同一种做法还有优化空间,比如用区间 int start, int end 来定义一个子问题。这种写法比较省事,不过毕竟还是生成了很多个 substring 的 copy,还有 string 做 key 在 hashmap 中查找的开销。

  • 每个子问题对应一个 tuple,其 T/F 各自一个 count;

  • 对于左右两个子问题,总共可能的组合数为 leftTotal * rightTotal;

  • 当操作符为 '&' 时,当前位置的 trueCnt = leftTrue * rightTrue;

  • 当操作符为 '|' 时,当前位置的 falseCnt = leftFalse * rightFalse;

  • 当操作符为 '^' 时,当前位置的 trueCnt = leftTrue * rightFalse + leftFalse * rightTrue;

  • 在每个操作符上,trueCnt + falseCnt = totalCnt;

public class Solution {
    private static class Tuple{
        int trueCnt;
        int falseCnt;
        public Tuple(int trueCnt, int falseCnt){
            this.trueCnt = trueCnt;
            this.falseCnt = falseCnt;
        }
    }

    private static Tuple evaluate(String str, HashMap<String, Tuple> map){
        // Assuming input string is valid;
        // Thus all boolean is at EVEN index; all operator at ODD index.

        // base case
        if(str.length() == 1){
            return (str.equals("T")) ? new Tuple(1,0)
                                     : new Tuple(0,1);
        }

        if(map.containsKey(str)){
            System.out.println("Cache hit for substring : " + str);
            return map.get(str);
        }

        Tuple tuple = new Tuple(0,0);

        for(int i = 1; i < str.length(); i += 2){
            Tuple left = evaluate(str.substring(0, i), map);
            Tuple right = evaluate(str.substring(i + 1), map);

            char operator = str.charAt(i);
            int curTotal = (left.falseCnt + left.trueCnt) *
                           (right.falseCnt + right.trueCnt);

            int curTrue = 0;
            int curFalse = 0;

            if(operator == '&'){
                curTrue = left.trueCnt * right.trueCnt;
                curFalse = curTotal - curTrue;
            } else if(operator == '|'){
                curFalse = left.falseCnt * right.falseCnt;
                curTrue = curTotal - curFalse;
            } else if(operator == '^'){
                curTrue = left.trueCnt * right.falseCnt +
                              left.falseCnt * right.trueCnt;
                curFalse = curTotal - curTrue;
            }

            tuple.trueCnt += curTrue;
            tuple.falseCnt += curFalse;
        }

        map.put(str, tuple);
        return tuple;
    }

    public static void main(String[] args) {
        String str1 = "T&F^T";
        String str2 = "T&F^T|F";
        String str3 = "T&F&T|T^F";
        String str4 = "T&F^T^T&F&T|T^F";

        HashMap<String, Tuple> map = new HashMap<>();
        Tuple tuple = evaluate(str3, map);

        System.out.println(tuple.trueCnt);
    }
}

先说一个自己的做法,写的是搜索,对于每一个 input string ,我们扫描并建立两个 list; 一个是 operands,一个是 operators. 每次 dfs 的时候挑一对数字和一个操作符,计算,修改 list 做 dfs,返回之后再把 list 改回来。

过了 15/25 个 test case 之后卡在了 "2*3-4*5" 上面,原因是多了一个额外的 -14,因为两个 subproblem 加括号的方式是一样的,顺序不同而已。

所以暴力搜索的问题在于,如果前面和后面加括号的方式是一样的,需要额外手段来判断 “重复” 。

而DFS 中想 cache 一个 List<>,远远不如 String 方便。

如果实在追求暴力到底,就干脆对每一个状态都搞序列化,记录当前的 operands & operators ,遇到重复的状态就直接返回了。

因此为了避免 dfs + backtracking 可能遇到的 overlap subproblem 返回多个结果的问题,可以先直接 divide & conquer,因为这个搜索结构是树状的,天生的 disjoint.

根据这个思想,我们可以以“操作符”为分隔,借鉴编译器和 reverse polish notation 中的 "expression tree" 来进行计算,结构如下:

这样左右子树都进行的完美的分隔,而且应为 input 为 string ,也非常容易对子问题进行记忆化搜索。

Divide & Conquer, 8ms,超过 41.50%.

public class Solution {
    public List<Integer> diffWaysToCompute(String input) {

        List<Integer> rst = new ArrayList<>();

        for(int i = 0; i < input.length(); i++){
            if(!Character.isDigit(input.charAt(i))){
                char operator = input.charAt(i);

                List<Integer> left = diffWaysToCompute(input.substring(0, i));
                List<Integer> right = diffWaysToCompute(input.substring(i + 1));

                for(int num1 : left){
                    for(int num2 : right){
                        if(operator == '+') rst.add(num1 + num2);
                        if(operator == '-') rst.add(num1 - num2);
                        if(operator == '*') rst.add(num1 * num2);
                    }
                }
            }
        }

        if(rst.size() == 0) rst.add(Integer.parseInt(input));

        return rst;
    }
}

带记忆化搜索,3ms,超过 91.61%

public class Solution {
    public List<Integer> diffWaysToCompute(String input) {
        return helper(input, new HashMap<String, List<Integer>>());
    }

    private List<Integer> helper(String str, HashMap<String, List<Integer>> map){
        if(map.containsKey(str)) return map.get(str);

        List<Integer> list = new ArrayList<>();

        for(int i = 0; i < str.length(); i++){
            char chr = str.charAt(i);
            if(!Character.isDigit(chr)){
                List<Integer> leftList = helper(str.substring(0, i), map);
                List<Integer> rightList = helper(str.substring(i + 1), map);

                for(int leftNum : leftList){
                    for(int rightNum : rightList){
                        if(chr == '+') list.add(leftNum + rightNum);
                        if(chr == '-') list.add(leftNum - rightNum);
                        if(chr == '*') list.add(leftNum * rightNum);
                    }
                }
            }
        }

        if(list.size() == 0) list.add(Integer.parseInt(str));

        map.put(str, list);
        return list;
    }
}

Last updated