Binary Tree Maximum Path Sum
Given a binary tree, find the maximum path sum.The path may start and end at any node in the tree.
For example:
Given the below binary tree,
1 / \ 2 3
Return
6
.Hide Tags
Tree Depth-first Search
Hide Similar Problems
(E) Path Sum (M) Sum Root to Leaf Numbers
难倒不难的一道题,不过还是有可以琢磨一下的地方。
稍微要注意的地方是递归调用的时候更新的最大值是父左右子的最大和,
但返回给上一层的是父加上较大的子(或者不加)的和,这两个东东是不一样的。
其次就是代码的组织问题了,先写了一个solution,因为习惯在往下走的时候先判断,而不是掉进去了才判断,
所以在叶子节点的时候,直接返回了它的值,从而在子节点返回之后要判断 vv, ll, rr, vv+ll, vv+rr, vv+ll+rr 六种情形的最大值,
返回的时候还要依据 ll, rr 是否为负数来决定向上返回的时候要不要加上 ll rr,所以代码看起来比较凌乱。
/**
* Definition for a binary tree node.
* struct TreeNode {
* int val;
* TreeNode *left;
* TreeNode *right;
* TreeNode(int x) : val(x), left(NULL), right(NULL) {}
* };
*/class Solution {
public:
int maxPathSum(TreeNode* root) {
int rst=INT_MIN;
if(root)
dfs(root, rst);
return rst;
}
int dfs(TreeNode* node, int& rst)
{
int ll=0, rr=0, vv=node->val, maxsum=vv;
if(node->left) {
ll = dfs(node->left, rst);
maxsum = mxx(vv, ll, vv+ll);
}
if(node->right) {
rr = dfs(node->right, rst);
maxsum = max(maxsum, mxx(rr, vv+rr, vv+ll+rr)); }
rst = max(rst, maxsum);
return max(vv + ((0<ll)?ll:0), vv + ((0<rr)?rr:0));
}
int mxx(int a, int b, int c)
{
return (a>=b) ? ((c>=a)?c:a) : ((c>=b)?c:b);
}
};
32 ms.
去讨论区看了看,有个java 的代码写得比较简洁:
他对叶子节点也是调进去了,因为不再有子节点,所以左右分别返回 0, 这样也就分别取到了 ll rr 的值并更新了最大值。
然后在上一层,因为负数对于构成更大的和没有意义,所以直接对 0 求了最大值才保存到临时变量,这样后面就简单了。
Elegant Java solution
public class Solution {
int max = Integer.MIN_VALUE;
public int maxPathSum(TreeNode root) {
helper(root);
return max;
}
// helper returns the max branch
// plus current node's value
int helper(TreeNode root) {
if (root == null) return 0;
int left = Math.max(helper(root.left), 0);
int right = Math.max(helper(root.right), 0);
max = Math.max(max, root.val + left + right);
return root.val + Math.max(left, right); }
}
于是又回头改了一下自己的代码,还是习惯先保护后调用,不过也精简了不少,也不再需要 mxx 函数了:
class Solution {
public:
int maxPathSum(TreeNode* root) {
int rst=INT_MIN;
if(root)
dfs(root, rst);
return rst;
}
int dfs(TreeNode* node, int& rst)
{
int ll=0, rr=0, vv=node->val;
if(node->left)
ll = max(0, dfs(node->left, rst));
if(node->right)
rr = max(0, dfs(node->right, rst));
rst = max(rst, vv+ll+rr);
return vv+max(ll,rr);
}
};
32 ms.
No comments:
Post a Comment