Sparse vector  Multiplication


https://github.com/tongzhang1994/Facebook-Interview-Coding/blob/master/Sparce%20Matrix%20Multiplication.java



public class Solution {//assume inputs are like {{2, 4}, {0, 10}, {3, 15}},index0 is index of non-zero vals,index1 is the val
    private Comparator<ArrayList<Integer>> sparseVectorComparator = new Comparator<ArrayList<Integer>>(){
        public int compare(ArrayList<Integer> a, ArrayList<Integer> b) {
            return a.get(0) - b.get(0);
        }
    };//remember to add ";" !!!
    
    public int sparseVectorMultiplication(ArrayList<ArrayList<Integer>> a, ArrayList<ArrayList<Integer>> b) {
        if (a == null || b == null || a.size() == 0 || b.size() == 0) {
            return 0;
        }
        int m = a.size();
        int n = b.size();
        int res = 0;
        
        //two inputs are unsorted, directly iterate the elements(brute force); O(m*n) time; if use sort, O(mlogm + nlogn)
        for (int i = 0; i < m; i++) {
            ArrayList<Integer> pairA = a.get(i);
            for (int j = 0; j < n; j++) {
                ArrayList<Integer> pairB = b.get(i);
                if (pairA.get(0) == pairB.get(0)) {//if their indices are the same, calculate and break
                    res += pairA.get(1) * pairB.get(1);
                    break;//pairA has been calculated, jump to next pair
                }
            }
        }
        
        //if we need to sort the inputs
        Collections.sort(a, sparseVectorComparator);
        
        //two inputs are sorted by index0, use two pointers(move the smaller, calculate the equal); O(m+n) time
        int i = 0;
        int j = 0;
        while (i < m && j < n) {
            ArrayList<Integer> pairA = a.get(i);
            ArrayList<Integer> pairB = b.get(j);
            if (pairA.get(0) < pairB.get(0)) {
                i++;
            } else if (pairA.get(0) > pairB.get(0)) {
                j++;
            } else {
                res += pairA.get(1) * pairB.get(1);
                i++;
                j++;
            }
        }
        
        //two inputs are sorted by index0, have same size, sometimes dense, sometimes sparse; two pointes + binary search
        int i = 0;
        int j = 0;
        int countA = 0;
        int countB = 0;
        while (i < m && j < n) {
            ArrayList<Integer> pairA = a.get(i);
            ArrayList<Integer> pairB = b.get(j);
            if (pairA.get(0) < pairB.get(0)) {
                i++;
                countA++;
                countB = 0;
                if (countA > Math.log(m)) {
                    i = search(a, i, m, pairB.get(0));
                    countA = 0;
                }
            } else if (pairA.get(0) > pairB.get(0)) {
                j++;
                countB++;
                countA = 0;
                if (countB > Math.log(n)) {
                    j = search(b, j, n, pairA.get(0));
                    countB = 0;
                }
            } else {
                res += pairA.get(1) * pairB.get(1);
                i++;
                j++;
                countA = 0;
                countB = 0;
            }
        }
        
        //two inputs are sorted by index0, input b is much larger than input a, iterate a and binary search b; O(m*logn) time
        int i = 0;
        int j = 0;
        while (i < m) {
            ArrayList<Integer> pairA = a.get(i++);
            j = search(b, j, n, pairA.get(0));
            ArrayList<Integer> pairB = b.get(j++);
            if (pairA.get(0) == pairB.get(0)) {
                res += pairA.get(1) * pairB.get(1);
            }
        }
        
        return res;
    }
    
    private int search(ArrayList<ArrayList<Integer>> array, int start, int end, int target) {
        while (start + 1 < end) {
            int mid = start + (end - start) / 2;
            ArrayList<Integer> pair = array.get(mid);
            if (pair.get(0) == target) {
                return mid;
            } else if (pair.get(0) < target) {
                start = mid;
            } else {
                end = mid;
            }
        }
        if (array.get(end).get(0) == target) {
            return end;
        }
        return start;
    }
}
面试官先问每个vector很大,不能在内存中存下怎么办,我说只需存下非零元素和他们的下标就行,然后问面试官是否可用预处理后的
这两个vector非零元素的index和value作为输入,面试官同意后写完O(M*N)的代码(输入未排序,只能一个个找),MN分别是两个vector长度。

又问这两个输入如果是根据下标排序好的怎么办,是否可以同时利用两个输入都是排序好这一个特性,最后写出了O(M + N)的双指针方法,
每次移动pair里index0较小的指针,如果相等则进行计算,再移动两个指针。

又问如果一个向量比另一个长很多怎么办,我说可以遍历长度短的那一个,然后用二分搜索的方法在另一个vector中找index相同的那个元素,
相乘加入到结果中,这样的话复杂度就是O(M*logN)。

又问如果两个数组一样长,且一会sparse一会dense怎么办。他说你可以在two pointer的扫描中内置一个切换二分搜索的机制。
看差值我说过,设计个反馈我说过,他说不好。他期待的解答是,two pointers找到下个位置需要m次比较,而直接二分搜需要log(n)次比较。
那么在你用two pointers方法移动log(n)次以后,就可以果断切换成二分搜索模式了。

Binary search如果找到了一个元素index,那就用这次的index作为下次binary search的开始。可以节约掉之前的东西,不用search了。
然后问,如果找不到呢,如何优化。说如果找不到,也返回上次search结束的index,然后下次接着search。
就是上一次找到了,就用这个index继续找这次的;如果找不到,也有一个ending index,就用那个index当starting index。
比如[1, 89,100],去找90;如果不存在,那么binary search的ending index应该是89,所以下次就从那个index开始。
如果找不到,会返回要插入的位置index + 1,index是要插入的位置,我写的就是返回要插入的index的。
但是不管返回89还是100的index都无所谓,反正只差一个,对performance没有明显影响的。