3437. Maximum Total Damage With Spell Casting

Difficulty: Medium

A magician has various spells.

You are given an array power, where each element represents the damage of a spell. Multiple spells can have the same damage value.

It is a known fact that if a magician decides to cast a spell with a damage of power[i], they cannot cast any spell with a damage of power[i] - 2, power[i] - 1, power[i] + 1, or power[i] + 2.

Each spell can be cast only once.

Return the maximum possible total damage that a magician can cast.
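Any solution can be sanity-checked against the rule above with a brute-force reference that tries every subset of spells. The sketch below is a hypothetical helper. The class `BruteForceDamage` and method `maxDamage` are illustrative names, not part of the problem. It is exponential in `power.length`, so it is only meant for tiny inputs such as the examples.

```java
public class BruteForceDamage {
    // Hypothetical reference checker: tries every subset of spell indices,
    // so it is only feasible for very small inputs (about 20 spells or fewer).
    static long maxDamage(int[] power) {
        int n = power.length;
        long best = 0;
        for (int mask = 0; mask < (1 << n); mask++) {
            long total = 0;
            boolean valid = true;
            for (int i = 0; i < n && valid; i++) {
                if ((mask & (1 << i)) == 0) continue;
                total += power[i];
                for (int j = i + 1; j < n; j++) {
                    if ((mask & (1 << j)) == 0) continue;
                    long diff = Math.abs((long) power[i] - power[j]);
                    // Equal damage values may coexist; a gap of 1 or 2 may not.
                    if (diff == 1 || diff == 2) {
                        valid = false;
                        break;
                    }
                }
            }
            if (valid) best = Math.max(best, total);
        }
        return best;
    }

    public static void main(String[] args) {
        System.out.println(maxDamage(new int[]{1, 1, 3, 4})); // 6
        System.out.println(maxDamage(new int[]{7, 1, 6, 6})); // 13
    }
}
```

Running `main` prints 6 and 13, which match the two examples.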

Example 1:

Input: power = [1,1,3,4]

Output: 6

Explanation:

The maximum possible damage of 6 is produced by casting spells 0, 1, 3 with damage 1, 1, 4.

Example 2:

Input: power = [7,1,6,6]

Output: 13

Explanation:

The maximum possible damage of 13 is produced by casting spells 1, 2, 3 with damage 1, 6, 6.
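Both examples depend on one observation. Spells with equal damage never block each other, so all copies of a value can be cast together. The sketch below collapses `power` into one total weight per distinct value and keeps the values in ascending order with a `TreeMap`. `GroupByDamage.weights` is a hypothetical helper name.

```java
import java.util.TreeMap;

public class GroupByDamage {
    // Sums damage per distinct value: weight(v) = v * (number of spells with damage v).
    // TreeMap keeps the keys in ascending order, which a DP over sorted values needs.
    static TreeMap<Integer, Long> weights(int[] power) {
        TreeMap<Integer, Long> w = new TreeMap<>();
        for (int p : power) w.merge(p, (long) p, Long::sum);
        return w;
    }

    public static void main(String[] args) {
        System.out.println(weights(new int[]{1, 1, 3, 4})); // {1=2, 3=3, 4=4}
        System.out.println(weights(new int[]{7, 1, 6, 6})); // {1=1, 6=12, 7=7}
    }
}
```

For Example 2 the weights are {1=1, 6=12, 7=7}. Because 6 and 7 differ by 1, the best choice is 1 + 12 = 13.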

Constraints:

  • 1 <= power.length <= 10^5
  • 1 <= power[i] <= 10^9
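These bounds mean the answer may not fit in an `int`. If all 10^5 spells deal 10^9 damage, none of them conflict, so the total is 10^5 × 10^9 = 10^14. The sketch below checks that arithmetic and shows the widening to `long` before multiplying. `DamageBounds` is a hypothetical class name.

```java
public class DamageBounds {
    // Upper bound on the answer under the stated constraints:
    // at most 10^5 spells, each dealing at most 10^9 damage.
    static final int MAX_LEN = 100_000;
    static final int MAX_POWER = 1_000_000_000;

    static long maxPossibleTotal() {
        // Widen to long before multiplying; int arithmetic would overflow.
        return (long) MAX_LEN * MAX_POWER;
    }

    public static void main(String[] args) {
        System.out.println(maxPossibleTotal());                      // 100000000000000
        System.out.println(maxPossibleTotal() > Integer.MAX_VALUE);  // true
    }
}
```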

Solution

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

class Solution {
    public long maximumTotalDamage(int[] power) {
        // Count how many spells share each damage value.
        Map<Integer, Integer> freq = new HashMap<>();
        for (int p : power)
            freq.merge(p, 1, Integer::sum);

        // Distinct damage values in ascending order.
        List<Integer> nums = new ArrayList<>(freq.keySet());
        Collections.sort(nums);
        int m = nums.size();

        // dp[i] = maximum damage using only the distinct values nums[i..m-1];
        // dp[m] = 0. Filled right to left instead of by memoized recursion,
        // which could overflow the call stack with up to 10^5 distinct values.
        long[] dp = new long[m + 1];
        for (int i = m - 1; i >= 0; i--) {
            int value = nums.get(i);
            // Option 1: skip every spell with this damage.
            long skip = dp[i + 1];
            // Option 2: cast every spell with this damage. That blocks value + 1
            // and value + 2, so continue from the first value above value + 2.
            int next = firstGreater(nums, i + 1, value + 2);
            long take = (long) value * freq.get(value) + dp[next];
            dp[i] = Math.max(skip, take);
        }
        return dp[0];
    }

    // Binary search: smallest index in [start, arr.size()) whose value exceeds
    // limit, or arr.size() if there is none. Assumes arr is sorted ascending.
    private int firstGreater(List<Integer> arr, int start, int limit) {
        int low = start, high = arr.size() - 1, ans = arr.size();
        while (low <= high) {
            int mid = low + (high - low) / 2;
            if (arr.get(mid) > limit) {
                ans = mid;
                high = mid - 1;
            } else {
                low = mid + 1;
            }
        }
        return ans;
    }
}

Complexity Analysis

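  • Time Complexity: O(n log n). Counting frequencies is O(n) (expected, using a hash map), sorting the m ≤ n distinct values is O(m log m), and each of the m DP states does one O(log m) binary search.
  • Space Complexity: O(n). The frequency map, the list of distinct values, and the dp array each hold at most n entries.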

Approach

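Spells with the same damage never block each other. Whenever one spell of damage v is cast, every spell of damage v can be cast too, contributing v × count(v). The problem therefore reduces to choosing a subset of the distinct damage values, no two of which differ by 1 or 2, that maximizes the sum of v × count(v).

Sort the distinct values into nums (m values) and let dp[i] be the maximum damage obtainable from nums[i..m-1]. For each i there are two choices:

  • Skip nums[i]: the best is dp[i + 1].
  • Cast nums[i]: gain nums[i] × count(nums[i]), then continue from the first index j with nums[j] > nums[i] + 2, for a total of nums[i] × count(nums[i]) + dp[j]. Because nums is sorted, j is found by binary search.

dp[i] is the larger of the two options, dp[m] = 0, and the answer is dp[0]. Smaller values never need to be checked against nums[i]. Whenever a smaller value was cast, its own jump already skipped every value within 2 of it.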
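The distinct values are processed in sorted order, so the binary search in the solution above can be replaced by a pointer that only moves forward. The sketch below is a hypothetical variant, not the original solution. It iterates values in ascending order, defines best[i] as the maximum damage from the first i distinct values, and keeps k at the first index whose value is at least vals[i] - 2. The DP loop is then O(m) after grouping, although building the `TreeMap` still makes the whole method O(n log n). `TwoPointerDamage` is an illustrative class name.

```java
import java.util.Map;
import java.util.TreeMap;

public class TwoPointerDamage {
    // Hypothetical variant: ascending-order DP where a forward-moving pointer
    // replaces the binary search for the last compatible smaller value.
    static long maxDamage(int[] power) {
        // Total damage per distinct value, keys kept in ascending order.
        TreeMap<Integer, Long> weight = new TreeMap<>();
        for (int p : power) weight.merge(p, (long) p, Long::sum);

        int m = weight.size();
        int[] vals = new int[m];
        long[] w = new long[m];
        int idx = 0;
        for (Map.Entry<Integer, Long> e : weight.entrySet()) {
            vals[idx] = e.getKey();
            w[idx] = e.getValue();
            idx++;
        }

        // best[i] = maximum damage using only the first i distinct values.
        long[] best = new long[m + 1];
        int k = 0; // first index whose value is >= vals[i] - 2 (never passes i)
        for (int i = 0; i < m; i++) {
            while (vals[k] < vals[i] - 2) k++;
            long skip = best[i];
            // Values at indices < k are all below vals[i] - 2, so they are compatible.
            long take = w[i] + best[k];
            best[i + 1] = Math.max(skip, take);
        }
        return best[m];
    }

    public static void main(String[] args) {
        System.out.println(maxDamage(new int[]{1, 1, 3, 4})); // 6
        System.out.println(maxDamage(new int[]{7, 1, 6, 6})); // 13
    }
}
```

As k increases it stays at or below i, because vals[i] is never less than vals[i] - 2. The total pointer movement is therefore at most m.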