一道Google top coder的850分例题及解答


 


原题:


 


假设有这样一种字符串,它们的长度不大于 26 ,而且若一个这样的字符串其长度为 m ,则这个字符串必定由 a, b, c ... z 中的前 m 个字母构成,同时我们保证每个字母出现且仅出现一次。比方说某个字符串长度为 5 ,那么它一定是由 a, b, c, d, e 这 5 个字母构成,不会多一个也不会少一个。嗯嗯,这样一来,一旦长度确定,这个字符串中有哪些字母也就确定了,唯一的区别就是这些字母的前后顺序而已。


 


现在我们用一个由大写字母 A 和 B 构成的序列来描述这类字符串里各个字母的前后顺序:


 


l         如果字母 b 在字母 a 的后面,那么序列的第一个字母就是 A (After),否则序列的第一个字母就是 B (Before);


l         如果字母 c 在字母 b 的后面,那么序列的第二个字母就是 A ,否则就是 B;


l         如果字母 d 在字母 c 的后面,那么 …… 不用多说了吧?直到这个字符串的结束。


 


这规则甚是简单,不过有个问题就是同一个 AB 序列,可能有多个字符串都与之相符,比方说序列"ABA",就有"acdb"、"cadb"等等好几种可能性。说的专业一点,这一个序列实际上对应了一个字符串集合。那么现在问题来了:给你一个这样的AB 序列,问你究竟有多少个不同的字符串能够与之相符?或者说这个序列对应的字符串集合有多大?注意,只要求个数,不要求枚举所有的字符串。


 


注:如果结果大于10亿就返回-1。


 


 


我的最终解答(没有考虑溢出的情况):


 


// CODE 1


// the best way


// O(N^2)


int countABbest(const string& AB)


{


    assert(AB.find_first_not_of("AB") == string::npos);


 


    vector<int> current, next; // should we reserve these vectors?


    current.push_back(1);


 


    for (string::const_iterator letter = AB.begin();


        letter != AB.end(); ++letter) {


        next.resize(current.size()+1); // or next.insert(next.end(), 2, 0);


        next[0] = 0; // in fact, we could set the entire vector to zero


 


        if (*letter == 'A') {


            partial_sum(current.begin(), current.end(), next.begin()+1);


        } else {


            partial_sum(current.rbegin(), current.rend(), next.begin()+1);


            reverse(next.begin(), next.end());


        }


        swap(current, next);


    }


    return accumulate(current.begin(), current.end(), 0);


}


 


int main()


{


    const char* AB = "ABBAAB";


    printf("'%s' : %d\n", AB, countABbest(AB));


}


 


下面谈一谈我在解决这个问题时的思路。


 


第一步 初步分析


 


以下“字符串”特指题目中提到的由小写字母a、b、c等等组成的字符串,每个字母出现且仅出现一次。显然题目要求我们写一个函数f,f的输入是一个长度为v 的AB序列w,w代表了一个字符串集合s(集合中的元素都是长度为m(m=v+1)的字符串),f的返回值是这个集合的元素个数|s|,即|s|=f(w)。用高中学过的一点排列组合知识,可分析知:


1.       长度为m的字符串有m! 个(’!’ 表示阶乘)因为这相当于m个不同字母的全排列;


2.       长度为v的AB序列有2^v个(’^’ 表示指数)因为每个位置有2种可能,一共有v个位置;


3.       由于2^v <= m! (m=v+1),所以AB序列的数目不大于字符串的数目。


4.       每个字符串刚好有一个AB序列与之对应。比如对于字符串”abdec”,我们很容易得知b在a后,c在b后,d在c前,e在d后,因此它对应的AB序列为”AABA”。可见拿到一个字符串,立刻就能求出它对应的那一个AB序列。


5.       每个AB序列至少对应一个字符串(当然也对应多个,因为字符串数目远远大于AB序列数目)。比如任取一个AB序列”ABA”,很容易构造出与它对应的字符串:


                                  i.    b在a后,得”ab”;


                                ii.    c在b前,得”acb”或”cab”;


                              iii.    d在c后,拿”acb”来说,可得”acdb”和”acbd”;拿”cab”来说,可得 ”cdab”、”cadb”和”cabd”;这样一共构造了5个与”ABA”对应的字符串,而且不会再有别的字符串了(why?)。


其实我们已经找到了蛮力解决问题的办法。


6.       根据4、5,得知如果穷举出长度为v的AB序列(共2^v个),并计算每个序列对应的字符串数目,那么把所有这些数目加起来,应该等于(v+1)!,这可以用作我们算法的一个检验。


7.       其实这可以看作集合的划分,把一个有 m! 个元素的集合U划分为2^v个不相交的子集s_0, s_1, s_{2^v–1},每个子集s_i是一个类别,每个字符串都属于一个类别,问题转变为求给定类别中有多少个元素。


 


第二步 蛮力解决


 


在想到前面的分析之前,我先用一种蛮力办法部分地解决了这个问题,思路是拿到一个长度为v的AB序列,穷举所有长度为v+1的字符串,遇到匹配的就记录下来。这样得到第一个程序,这个程序虽然效率极低,但可以用来检验后面程序的正确性,是个标竿。


 


// CODE 2


bool match(const string& AB, const string& str)


{


    // many ways to improve this function, but we won’t bother it.


    for (size_t i = 0; i < AB.length(); ++i) {


        size_t first = str.find('a'+i);


        size_t second = str.find('a'+i+1);


        assert(first != string::npos && second != string::npos);


 


        if (AB[i] == 'A' && first > second) {


            return false;


        } else if (AB[i] == 'B' && first < second) {


            return false;


        }


    }


    return true;


}


 


// the stupid way


// O(N! * N^2)


int countAB(const string& AB)


{


    assert(AB.find_first_not_of("AB") == string::npos);


 


    string str;


    int count = 0;


    int m = (int)AB.length() + 1;


   


    // construct the initial string


    for (int i = 0; i < m; ++i) {


        str.push_back('a'+i);


    }


   


    do {


        if (match(AB, str)) {


            printf("%s, ", str.c_str());


            count++;


        }


    } while (next_permutation(str.begin(), str.end()));


   


    return count;


}


 


上面这个程序是以AB序列为中心,想办法找到与它匹配的字符串。为了看它能否通过第6点分析的检验,我写了一个enumAB(int v)函数,用来穷举长度为v的所有AB序列,并做检验(检验基本靠眼)。


 


// CODE 3


void enumAB(int v)


{


    assert(0 <= v && v < 26);


    int nAB = 1 << v;


    int total = 0;


 


    for (int i = 0; i < nAB; ++i) {


        string AB;


        for (int bit = v-1; bit >= 0; --bit) {


            if (i & (1 << bit)) {


                AB.push_back('B');


            } else {


                AB.push_back('A');


            }


        }


        int count = countAB(AB);


        total += count;


        printf("%s : %d\n", AB.c_str(), count);


    }


 


    printf("\nTotal strings: %d\n", total);


}


 


以下是enumAB(4)的运行结果(5!=120,初步检验通过):


 


AAAA : 1


AAAB : 4


AABA : 9


AABB : 6


ABAA : 9


ABAB : 16


ABBA : 11


ABBB : 4


BAAA : 4


BAAB : 11


BABA : 16


BABB : 9


BBAA : 6


BBAB : 9


BBBA : 4


BBBB : 1


 


Total strings: 120


 


如果想穷举所有AB序列和它们对应的字符串,还可以用一种效率稍高的蛮力算法,以字符串为中心,穷举所有长度为m的字符串,把它归入相应的AB序列名下。代码如下。


 


// CODE 4


string getAB(const string& str)


{


    const char* alphabet = "abcdefghijklmnopqrstuvwxyz";


    assert(str.find_first_not_of(alphabet, 0, str.length()) == string::npos);


 


    int pos[26] = {0};


    char AB[26] = {0};


 


    int m = (int)str.length();


   


    for (int i = 0; i < m; ++i) {


        pos[str[i]-'a'] = i;


    }


 


    for (int i = 0; i < m-1; ++i) {


        AB[i] = pos[i] < pos[i+1] ? 'A' : 'B';


    }


 


    return AB; // we are not return the local char array, but a string object.


}


 


void enumStr(int m)


{


    string str;


    int nAB = 0;


   


    for (int i = 0; i < m; ++i) {


        str.push_back(char('a'+i));


    }


   


    map<string, vector<string> > AB2strs;


   


    do {


        string AB = getAB(str);


        //printf("%s is of %s\n", str.c_str(), AB.c_str());


        AB2strs[AB].push_back(str);


    } while (next_permutation(str.begin(), str.end()));


   


 


    for (map<string, vector<string> >::iterator it = AB2strs.begin();


        it != AB2strs.end(); ++it) {


            ++nAB;


            printf("%s (%d): ", it->first.c_str(), it->second.size());


            for (vector<string>::iterator str = it->second.begin();


                str != it->second.end(); ++str) {


                    printf("%s, ", str->c_str());


            }


            printf("\n");


    }


    printf("\nTotal ABs : %d\n", nAB);


}


 


以下是enumStr(4)的运行结果(2^3=8,初步检验通过):


 


AAA (1): abcd,


AAB (3): abdc, adbc, dabc,


ABA (5): acbd, acdb, cabd, cadb, cdab,


ABB (3): adcb, dacb, dcab,


BAA (3): bacd, bcad, bcda,


BAB (5): badc, bdac, bdca, dbac, dbca,


BBA (3): cbad, cbda, cdba,


BBB (1): dcba,


 


Total ABs : 8


 


第三步 进阶分析


 


我们也可以根据前面第5点分析,做出一个更高效的蛮力算法,不过蛮力毕竟是蛮力,还是让我们动动脑筋,做个真正高效的算法吧。


我第一次拿到这个问题时,先用蛮力算法打印出前面的结果,试图分析其规律,没成功。便又在纸上演算了了一阵,发现其实可以递推解决(当然也可以递归解决),以下内容最好在纸上演算。比如对于序列”AAA”,字母d只可能在第3号位置出现一次(abcd);递推一下,对于序列”AAAB”,e在d前,那么e可以在第0、1、2、3号位置各出现一次(eabcd、aebcd、abecd、abced)。


又比如根据以前面第5点分析,如果我们知道对于序列”AB”,字母c可能在第0号位置出现一次(cab)、在第1号位置出现一次(acb);那么对于序列”ABA”,字母d会在第1、2、3号位置分别出现1、2、2次,因此”ABA”对应的字符串共有5个;同理对于序列”ABB”,字母d会在第0、1号位置分别出现2、1次,因此”ABB”对应的字符串共有3个。


继续递推,对于序列”ABBA”,e在d后,那么e可以在第1、2、3、4号位置分别出现2、3、3、3次(具体说来,对于d在第0号位置出现2次,那么e可以在第1、2、3、4号位置各出现2次;d在第1号位置出现1次,那么e可以在第2、3、4号位置各出现1次,对位加起来就得到前面“2、3、3、3”的结果),因此”ABBA”对应的字符串共有11个。


到这里,我们已经发现递推的规律了:对于AB序列w,用二维数组occurs[][]表示第letter个字母在位置pos出现的次数occurs[letter][pos](这个说法不太严格,应该说是w的前面长度为letter的子序列对应的字符串中,最大那个字母出现的位置和次数,呵呵,还是比较绕口)。如果字母p在位置q1出现n1次,而AB序列的当前元素为’A’,那么字母p+1会在位置q1+1, q1+2, . . . , p各出现n1次;如果AB序列的当前元素为’B’,那么字母p+1会在位置0, 1, . . . , q1各出现n1次;如果字母p还在q2位置出现了n2次,那么对于’A’ 情况,字母p+1还会在位置q2+1, q2+2, . . . , p各出现n2次;那么对于’B’ 情况,字母p+1还会在位置0, 1, . . . , q2各出现n2次。需要把这些情况都累加起来。


对于序列”ABBAA”,递推表如下:


1, 0, 0, 0, 0, 0       字母a在位置0出现1次


0, 1, 0, 0, 0, 0       字母b在位置1出现1次


1, 1, 0, 0, 0, 0       字母c在位置0、1分别出现1次


2, 1, 0, 0, 0, 0       字母d在位置0、1分别出现2、1次


0, 2, 3, 3, 3, 0       字母e在位置1、2、3、4分别出现2、3、3、3次


0, 0, 2, 5, 8, 11     字母f 在位置2、3、4、5分别出现2、5、8、11次


可知对应的字符串有26个。如果细心,已经能发现递推中的部分和(partial sum)关系。


 


第四步 解决


 


既然递推关系有了,很容易就能写出代码。这个算法的复杂度是O(N^3)。


 


// CODE 5


// the better way


// O(N^3)


int countABbetter(const string& AB)


{


    assert(AB.find_first_not_of("AB") == string::npos);


 


    int v = (int)AB.length();


    int m = v + 1;


 


    // 'letter' at 'pos' occurs 'occurs[letter][pos]' times.


    vector<vector<int> > occurs(m, vector<int>(m, 0));


 


    // letter 'a' at pos 0, 1 time


    occurs[0][0] = 1;


 


    for (int letter = 1; letter < m; ++letter) {


        for (int pos = 0; pos < letter; ++pos) {


            int first_pos = 0;


            int last_pos = 0;


 


            if (AB[letter-1] == 'A') {


                // after current pos


                first_pos = pos + 1;


                last_pos = letter;


            } else {


                assert(AB[letter-1] == 'B');


                // before (and at) current pos


                first_pos = 0;


                last_pos = pos;


            }


 


            int occur = occurs[letter-1][pos];


            for (int t = first_pos; t <= last_pos; ++t) {


                occurs[letter][t] += occur;


                assert(occurs[letter][t] >= 0);


            }


        }


    }


    return accumulate(occurs[m-1].begin(), occurs[m-1].end(), 0);


}


 


第五步 优化


 


前面提过一句,在递推的过程中其实隐藏了一个“部分和”的关系,利用这一性质,可以很容易地将复杂度降为O(N^2),而且递推只是根据当前字母的出现位置退出下一字母的出现位置,因此可以省去2维数组,改用两个vector就行了。最后的代码就是前面一开始列出的 CODE 1。


 


第六步 展望


 


我猜测算法的复杂度能进一步降到 O(N log N),不过自己已经没有能力实现了。另外,为了附庸风雅一把,我发现整个递推算法的过程如果用矩阵来描述,会变得相当清楚。比如对于序列”ABAAB”,很容易构造矩阵A1、B2、A3、A4、B5(每个矩阵都是6阶方阵),初始向量x=[1 0 0 0 0 0]T,生成向量y=B5*A4*A3*B2*A1*x,那么对应的字符串有sum(y)个(sum表示y的各分量之和)。


注:也可以定义初始向量x=[1],矩阵A1是2x1、矩阵B2是3x2、矩阵A3是4x3、……、矩阵B5是6x5,一样可以计算出向量y。


例如:(这些矩阵中的元素都是0或1,排列起来像三角形(因为是求部分和),很有规律的。)


A1 = [0; 1]


B2 = [1 1; 0 1; 0 0]


A3 = [0 0 0; 1 0 0; 1 1 0; 1 1 1]


A4 = [0 0 0 0; 1 0 0 0; 1 1 0 0; 1 1 1 0; 1 1 1 1]


B5 = [1 1 1 1 1; 0 1 1 1 1; 0 0 1 1 1; 0 0 0 1 1; 0 0 0 0 1; 0 0 0 0 0]


算出y = B5*A4*A3*B2*A1 = [9 9 9 8 5 0] T


sum(y) = 40,与前面程序的结果相同。


 


. 完 .