目次
やりたいこと
- 全てのNodeの合計を算出 (=total)
- 各Nodeで、SubTreeの合計を算出 (=sub_sum)
- 2の全要素に対して、anser = max(anser, sub_sum * (total – sub_sum)); を行う
- anser % 109 + 7を返す
Hintを見ながら、こんな感じかな?となりました。
コストを無視するなら……
最初に考えたのはパワープレーでぶん回す方法です。
class Solution {
private:
static constexpr int mod = 1e9 + 7;
public:
int maxProduct(TreeNode* root) {
long anser = 0;
forEach(root, [this, &anser, total = sum(root)] (const TreeNode& node) mutable {
const long sub_sum = sum(&node);
anser = std::max(anser, sub_sum * (total - sub_sum));
});
return anser % mod;
}
private:
long sum(const TreeNode* root) {
long total = 0;
forEach(root, [&total] (const TreeNode& node) mutable {
total += node.val;
});
return total;
}
void forEach(const TreeNode* root, const std::function<void(const TreeNode&)>& action) {
if (root == nullptr) {
return;
}
action(*root);
forEach(root->left, action);
forEach(root->right, action);
}
};
このコードは保守性/可読性の観点で見たら、ある程度まとまっているかなと思うのですが、forEachが入れ子になっているのでNodeが増えていくごとに計算量が倍々で増えてしまいます。
LeetCodeで流してみると、TimeLimitになります……
たどり着いた実装
経緯は忘れましたが、「やりたいこと」の 1. と 2. を同時に行うという結論に至りました。
class Solution {
private:
static constexpr int mod = 1e9 + 7;
public:
int maxProduct(TreeNode* root) {
long anser = 0;
const auto [total, sums] = treeSums(root);
for (long sum : move(sums)) {
anser = std::max(anser, sum * (total - sum));
}
return anser % mod;
}
private:
/// @return first: total, second: sums of each node
std::pair<long, std::list<long>> treeSums(const TreeNode* root) {
if (root == nullptr) {
return { 0, {} };
}
auto [total, sums] = treeSums(root->left);
auto [right_total, right_sums] = treeSums(root->right);
sums.merge(move(right_sums));
total += right_total + root->val;
sums.emplace_back(total);
return { total, move(sums) };
}
};
全ての合計値と、各Nodeでの合計のリストをくれる、treeSums()を用意しました。
これで全てのNodeを一回舐めれば良いだけになりました。
以上です!