Given a binary tree, collect its nodes as if you were doing the following: collect and remove all leaves, and repeat until the tree is empty.

Example:
Given the binary tree:
          1
         / \
        2   3
       / \     
      4   5    
Returns [4, 5, 3], [2], [1].

Explanation:
1. Removing the leaves [4, 5, 3] would result in this tree:

          1
         / 
        2          
2. Now removing the leaf [2] would result in this tree:

          1          
3. Now removing the leaf [1] would result in the empty tree:

          []         
Returns [4, 5, 3], [2], [1].
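Read literally, the procedure above can be simulated by repeatedly detaching the current leaves from their parents. The sketch below does exactly that. It is not either of the solutions that follow. `LeafPeeling`, `removeLeaves` and `exampleTree` are hypothetical names, and the nested `TreeNode` mirrors the LeetCode definition. It rebuilds the example tree and reproduces the expected rounds.

```java
import java.util.ArrayList;
import java.util.List;

public class LeafPeeling {
    // Minimal node type mirroring the LeetCode TreeNode definition.
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int x) { val = x; }
    }

    // Removes every node that is a leaf at the start of this round, appending
    // its value to `out` in left-to-right order. Returns the new subtree root
    // (null if `node` itself was a leaf).
    static TreeNode removeLeaves(TreeNode node, List<Integer> out) {
        if (node == null) return null;
        if (node.left == null && node.right == null) {
            out.add(node.val);
            return null; // detach this leaf from its parent
        }
        node.left = removeLeaves(node.left, out);
        node.right = removeLeaves(node.right, out);
        return node;
    }

    // Repeats the "collect and remove all leaves" round until the tree is empty.
    // Note: this mutates the input tree.
    static List<List<Integer>> findLeaves(TreeNode root) {
        List<List<Integer>> rounds = new ArrayList<>();
        while (root != null) {
            List<Integer> leaves = new ArrayList<>();
            root = removeLeaves(root, leaves);
            rounds.add(leaves);
        }
        return rounds;
    }

    // Builds the example tree: 1 with children 2 and 3; 2 with children 4 and 5.
    static TreeNode exampleTree() {
        TreeNode root = new TreeNode(1);
        root.left = new TreeNode(2);
        root.right = new TreeNode(3);
        root.left.left = new TreeNode(4);
        root.left.right = new TreeNode(5);
        return root;
    }

    public static void main(String[] args) {
        System.out.println(findLeaves(exampleTree())); // [[4, 5, 3], [2], [1]]
    }
}
```

Each round re-walks the whole remaining tree, so a degenerate (list-shaped) tree costs O(n²) in the worst case; that cost is what the height-based approach avoids.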

Better Solution: https://discuss.leetcode.com/topic/49194/10-lines-simple-java-solution-using-recursion-with-explanation/2

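For this question we need to take a bottom-up approach. The key is to find the height of each node, where the height of a node is the number of edges from the node to the deepest leaf below it. Leaves have height 0 and are removed in the first round, and a node becomes a leaf exactly one round after its last child is removed, so a node of height k is removed in round k (counting from 0). The height can therefore be used directly as the index into the result list.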
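To make the height definition concrete, the sketch below computes the height of every node in the example tree with a post-order traversal. `NodeHeights`, `height` and `exampleTree` are hypothetical names, and keying the map by `val` assumes node values are distinct, which holds for the example.

```java
import java.util.LinkedHashMap;
import java.util.Map;

public class NodeHeights {
    // Minimal node type mirroring the LeetCode TreeNode definition.
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int x) { val = x; }
    }

    // Returns the number of edges from `node` down to its deepest leaf.
    // A null child counts as -1, so a leaf gets 1 + max(-1, -1) = 0.
    // Each node is recorded only after both children (post-order).
    static int height(TreeNode node, Map<Integer, Integer> heights) {
        if (node == null) return -1;
        int h = 1 + Math.max(height(node.left, heights), height(node.right, heights));
        heights.put(node.val, h); // assumes distinct node values
        return h;
    }

    // Builds the example tree: 1 with children 2 and 3; 2 with children 4 and 5.
    static TreeNode exampleTree() {
        TreeNode root = new TreeNode(1);
        root.left = new TreeNode(2);
        root.right = new TreeNode(3);
        root.left.left = new TreeNode(4);
        root.left.right = new TreeNode(5);
        return root;
    }

    public static void main(String[] args) {
        Map<Integer, Integer> heights = new LinkedHashMap<>();
        height(exampleTree(), heights);
        System.out.println(heights); // {4=0, 5=0, 2=1, 3=0, 1=2}
    }
}
```

Grouping these values by height (0: 4, 5, 3; 1: 2; 2: 1) gives exactly the rounds listed in the example.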

import java.util.ArrayList;
import java.util.List;

/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode(int x) { val = x; }
 * }
 */
public class Solution {
    public List<List<Integer>> findLeaves(TreeNode root) {
        List<List<Integer>> res = new ArrayList<>();
        helper(root, res);
        return res;
    }

    // Returns the height of cur (edges to its deepest leaf) and files cur.val
    // under res.get(height). A null child returns -1, so a leaf has height 0.
    public int helper(TreeNode cur, List<List<Integer>> res) {
        if (cur == null) return -1;
        int level = 1 + Math.max(helper(cur.left, res), helper(cur.right, res));
        if (res.size() <= level)
            res.add(new ArrayList<Integer>()); // first node seen with this height
        res.get(level).add(cur.val);
        cur.left = cur.right = null; // "remove" cur's children (mutates the input tree)
        return level;
    }
}
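The `cur.left = cur.right = null` line mirrors the "remove" step but destroys the caller's tree. Because the grouping depends only on the returned heights, dropping that line gives the same result and leaves the tree intact. A sketch of that variant; `FindLeavesNonDestructive`, `height` and `exampleTree` are hypothetical names, and the nested `TreeNode` mirrors the LeetCode definition.

```java
import java.util.ArrayList;
import java.util.List;

public class FindLeavesNonDestructive {
    // Minimal node type mirroring the LeetCode TreeNode definition.
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int x) { val = x; }
    }

    static List<List<Integer>> findLeaves(TreeNode root) {
        List<List<Integer>> res = new ArrayList<>();
        height(root, res);
        return res;
    }

    // Same recursion as helper above, minus the line that clears the children.
    static int height(TreeNode cur, List<List<Integer>> res) {
        if (cur == null) return -1;
        int level = 1 + Math.max(height(cur.left, res), height(cur.right, res));
        if (res.size() <= level) res.add(new ArrayList<>()); // first node at this height
        res.get(level).add(cur.val);
        return level; // cur.left and cur.right are left untouched
    }

    // Builds the example tree: 1 with children 2 and 3; 2 with children 4 and 5.
    static TreeNode exampleTree() {
        TreeNode root = new TreeNode(1);
        root.left = new TreeNode(2);
        root.right = new TreeNode(3);
        root.left.left = new TreeNode(4);
        root.left.right = new TreeNode(5);
        return root;
    }

    public static void main(String[] args) {
        TreeNode root = exampleTree();
        System.out.println(findLeaves(root));   // [[4, 5, 3], [2], [1]]
        System.out.println(root.left.left.val); // 4: the input tree is still intact
    }
}
```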

 

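First-time solution: HashSet + DFS. Each pass walks down from the root, collects every node whose children are all null or already visited, and marks it visited; passes repeat until the root itself is visited. Because every pass re-walks the remaining tree, this costs O(n·h) time versus O(n) for the height-based solution.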

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;

/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode(int x) { val = x; }
 * }
 */
public class Solution {
    public List<List<Integer>> findLeaves(TreeNode root) {
        List<List<Integer>> res = new ArrayList<>();
        if (root == null) return res;
        // Nodes collected in an earlier pass count as removed.
        HashSet<TreeNode> visited = new HashSet<>();
        while (!visited.contains(root)) {
            List<Integer> leaves = new ArrayList<>();
            helper(root, leaves, visited); // one pass = one round of leaf removal
            res.add(leaves);
        }
        return res;
    }

    // Collects the current leaves under cur: nodes whose children are all null or visited.
    public void helper(TreeNode cur, List<Integer> leaves, HashSet<TreeNode> visited) {
        boolean leftGone = cur.left == null || visited.contains(cur.left);
        boolean rightGone = cur.right == null || visited.contains(cur.right);
        if (leftGone && rightGone) {
            leaves.add(cur.val);
            visited.add(cur);
            return;
        }
        if (!leftGone) helper(cur.left, leaves, visited);
        if (!rightGone) helper(cur.right, leaves, visited);
    }
}