网上看了很多人的答案，都说求完全二叉树结点个数的最优解是 O(n)。我想到了一种更快的算法，复杂度是 O((log n)^2)：对左右子树进行二分，找到最后一层最右边的那个结点即可：
#include <iostream>
#include <cmath>
#include <stdlib.h>

// Counts the nodes of a complete binary tree without visiting every node.
//
// Idea: the node count equals the level-order position of the last node on the
// deepest level. That node is located by a binary search over subtrees: at
// each step the deepest-rightmost node of the left subtree is compared with
// the deepest-leftmost node of the right subtree to decide which half holds
// the last node. Each step walks one root-to-leaf path and descends one
// level, giving O((log n)^2) work overall.
//
// Conventions used throughout:
//   * depth is 1-based (the root is at depth 1);
//   * index is the 1-based position of a node within its own level, counted
//     from the left as if the level were completely filled.

// A binary-tree node that owns (and deletes) both of its children.
struct Node {
    Node* left;
    Node* right;

    // Children start out empty so the destructor never sees garbage pointers,
    // regardless of how the node was created.
    Node() : left(NULL), right(NULL) {}

    // Recursively destroys the whole subtree rooted at this node.
    ~Node() {
        delete left;  // deleting a null pointer is a no-op
        left = NULL;
        delete right;
        right = NULL;
    }

private:
    // Not copyable: a copy would share the children and delete them twice.
    Node(const Node&);
    Node& operator=(const Node&);
};

// Number of nodes in all completely filled levels above `depth`,
// i.e. 2^(depth - 1) - 1, computed with integers instead of pow().
static int NodesAboveLevel(int depth) {
    if (depth <= 1) {
        return 0;
    }
    return (1 << (depth - 1)) - 1;
}

// Level index of the left child of the node whose level index is `index`.
static int LeftChildIndex(int index) {
    return (index - 1) * 2 + 1;
}

// Level index of the right child of the node whose level index is `index`.
static int RightChildIndex(int index) {
    return (index - 1) * 2 + 2;
}

// Reports a malformed (non-complete) tree and terminates the program.
static void FailIllegalTree() {
    std::cout << "illegal tree";
    exit(1);
}

// Returns the depth of the tree measured along its leftmost path
// (0 for an empty tree). For a complete binary tree this is its height.
int GetDepth(Node* root) {
    int depth = 0;
    while (root) {
        depth++;
        root = root->left;
    }
    return depth;
}

// Walks down from `root`, preferring the right child and falling back to the
// left child, until a leaf is reached.
//
// On entry *depth / *index must describe `root`; on return they describe the
// leaf that was reached. Both are left untouched when `root` is null.
void GetMostRightDepthAndIndex(Node* root, int* depth, int* index) {
    while (root) {
        if (root->right) {
            root = root->right;
            *index = RightChildIndex(*index);
        } else if (root->left) {
            root = root->left;
            *index = LeftChildIndex(*index);
        } else {
            return;  // leaf reached
        }
        *depth += 1;
    }
}

// Mirror image of GetMostRightDepthAndIndex: walks down from `root`,
// preferring the left child and falling back to the right child, until a
// leaf is reached.
//
// On entry *depth / *index must describe `root`; on return they describe the
// leaf that was reached. Both are left untouched when `root` is null.
void GetMostLeftDepthAndIndex(Node* root, int* depth, int* index) {
    while (root) {
        if (root->left) {
            root = root->left;
            *index = LeftChildIndex(*index);
        } else if (root->right) {  // was a duplicated `root->left` test
            root = root->right;
            *index = RightChildIndex(*index);
        } else {
            return;  // leaf reached
        }
        *depth += 1;
    }
}

// Computes the number of nodes of the complete binary tree that `root`
// belongs to and stores it in *node_num.
//
//   root       subtree currently being searched (null means "no nodes":
//              *node_num is set to 0)
//   cur_depth  depth of `root` in the whole tree (1 for the real root)
//   cur_index  level index of `root` (1 for the real root)
//   node_num   output; overwritten with the total node count
//   max_depth  height of the whole tree, as returned by GetDepth()
//
// Shapes that the search recognises as not being a complete binary tree
// print "illegal tree" and terminate the program with exit code 1. The check
// only covers the paths the search actually walks.
void GetNodeNum(Node* root, int cur_depth, int cur_index, int* node_num, int max_depth) {
    if (!root) {
        *node_num = 0;
        return;
    }

    if (!root->left) {
        // `root` is the last node of the tree: everything above its level is
        // full, and cur_index nodes exist on its own level.
        *node_num = NodesAboveLevel(cur_depth) + cur_index;
        return;
    }

    const int left_index = LeftChildIndex(cur_index);
    const int right_index = RightChildIndex(cur_index);

    // Deepest-rightmost node of the left subtree (root->left is non-null here).
    int ml_depth = cur_depth + 1;
    int ml_index = left_index;
    GetMostRightDepthAndIndex(root->left, &ml_depth, &ml_index);

    // Deepest-leftmost node of the right subtree; stays at cur_depth when
    // there is no right child at all.
    int mr_depth = root->right ? cur_depth + 1 : cur_depth;
    int mr_index = right_index;
    GetMostLeftDepthAndIndex(root->right, &mr_depth, &mr_index);

    if (ml_depth == mr_depth) {
        if (ml_depth == max_depth) {
            // Left subtree is full down to the last level and the right
            // subtree also reaches it: the last node is on the right.
            GetNodeNum(root->right, cur_depth + 1, right_index, node_num, max_depth);
        } else if (ml_depth < max_depth) {
            // Both probes stopped one level short: the last level ends
            // somewhere inside the left subtree.
            GetNodeNum(root->left, cur_depth + 1, left_index, node_num, max_depth);
        } else {
            FailIllegalTree();
        }
    } else if (ml_depth > mr_depth && ml_depth == max_depth && mr_depth == max_depth - 1) {
        // The last level ends exactly at the rightmost node of the left
        // subtree, so that node is the last node of the whole tree.
        *node_num = NodesAboveLevel(ml_depth) + ml_index;
    } else {
        // Any other depth combination cannot occur in a complete tree.
        FailIllegalTree();
    }
}

int main() {
    // Build a sample complete binary tree: four full levels (15 nodes) plus a
    // single node on level 5, i.e. 16 nodes in total.

    // depth 1
    Node* root = new Node();

    // depth 2
    root->left = new Node();
    root->right = new Node();

    // depth 3
    root->left->left = new Node();
    root->left->right = new Node();
    root->right->left = new Node();
    root->right->right = new Node();

    // depth 4
    root->left->left->left = new Node();
    root->left->left->right = new Node();
    root->left->right->left = new Node();
    root->left->right->right = new Node();
    root->right->left->left = new Node();
    root->right->left->right = new Node();
    root->right->right->left = new Node();
    root->right->right->right = new Node();

    // depth 5
    root->left->left->left->left = new Node();

    int root_depth = 1;
    int root_index = 1;
    int max_depth = GetDepth(root);
    int node_num = 0;
    GetNodeNum(root, root_depth, root_index, &node_num, max_depth);
    std::cout << "node_num = " << node_num << std::endl;

    delete root;  // releases the whole tree via Node::~Node
}