
3592. Find X-Sum of All K-Long Subarrays II

Difficulty: Hard





You are given an array nums of n integers and two integers k and x.

The x-sum of an array is calculated by the following procedure:

  • Count the occurrences of all elements in the array.
  • Keep only the occurrences of the top x most frequent elements. If two elements have the same number of occurrences, the element with the bigger value is considered more frequent.
  • Calculate the sum of the resulting array.

Note that if an array has fewer than x distinct elements, its x-sum is the sum of the array.

Return an integer array answer of length n - k + 1 where answer[i] is the x-sum of the subarray nums[i..i + k - 1].
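To make the definition concrete, here is a minimal brute-force sketch that follows the three steps literally for every window. The class and method names (NaiveXSum, xSum) are hypothetical, and at roughly O((n - k + 1) * k log k) time it is meant only for cross-checking small inputs, not for the full constraints.

```java
import java.util.*;

// Hypothetical brute-force reference: recomputes the x-sum of every window from scratch.
class NaiveXSum {
    // x-sum of nums[from..to-1], following the three steps of the definition.
    static long xSum(int[] nums, int from, int to, int x) {
        // Step 1: count the occurrences of every element.
        Map<Integer, Integer> count = new HashMap<>();
        for (int i = from; i < to; i++) count.merge(nums[i], 1, Integer::sum);

        // Most frequent first; on a tie, the bigger value counts as more frequent.
        List<Map.Entry<Integer, Integer>> entries = new ArrayList<>(count.entrySet());
        entries.sort((a, b) -> {
            int byFreq = Integer.compare(b.getValue(), a.getValue());
            return byFreq != 0 ? byFreq : Integer.compare(b.getKey(), a.getKey());
        });

        // Steps 2 and 3: sum the occurrences of the top x elements
        // (all of them when there are fewer than x distinct elements).
        long sum = 0;
        for (int j = 0; j < Math.min(x, entries.size()); j++) {
            sum += (long) entries.get(j).getKey() * entries.get(j).getValue();
        }
        return sum;
    }

    static long[] findXSum(int[] nums, int k, int x) {
        long[] answer = new long[nums.length - k + 1];
        for (int i = 0; i + k <= nums.length; i++) answer[i] = xSum(nums, i, i + k, x);
        return answer;
    }
}
```

On the inputs of Example 1 and Example 2 it returns [6, 10, 12] and [11, 15, 15, 15, 12].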

 

Example 1:

Input: nums = [1,1,2,2,3,4,2,3], k = 6, x = 2

Output: [6,10,12]

Explanation:

  • For subarray [1, 1, 2, 2, 3, 4], only elements 1 and 2 will be kept in the resulting array. Hence, answer[0] = 1 + 1 + 2 + 2.
  • For subarray [1, 2, 2, 3, 4, 2], only elements 2 and 4 will be kept in the resulting array. Hence, answer[1] = 2 + 2 + 2 + 4. Note that 4 is kept in the array since it is bigger than 3 and 1, which occur the same number of times.
  • For subarray [2, 2, 3, 4, 2, 3], only elements 2 and 3 are kept in the resulting array. Hence, answer[2] = 2 + 2 + 2 + 3 + 3.

Example 2:

Input: nums = [3,8,7,8,7,5], k = 2, x = 2

Output: [11,15,15,15,12]

Explanation:

Since k == x, answer[i] is equal to the sum of the subarray nums[i..i + k - 1].
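The reasoning behind Example 2: a window of length k holds at most k distinct values, so when k == x it never has more than x of them and its x-sum is just its plain sum. A sliding sum then handles this special case in O(n); the helper name windowSums below is hypothetical.

```java
// Hypothetical helper for the k == x case: every window has at most k == x distinct
// values, so its x-sum equals its plain sum, maintained with a sliding window.
class WindowSums {
    static long[] windowSums(int[] nums, int k) {
        long[] answer = new long[nums.length - k + 1];
        long sum = 0;
        for (int i = 0; i < nums.length; i++) {
            sum += nums[i];                    // nums[i] enters the window
            if (i >= k) sum -= nums[i - k];    // nums[i - k] leaves the window
            if (i >= k - 1) answer[i - k + 1] = sum;
        }
        return answer;
    }
}
```

For nums = [3,8,7,8,7,5] and k = 2 this returns [11, 15, 15, 15, 12], the output of Example 2.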

 

Constraints:

  • nums.length == n
  • 1 <= n <= 10^5
  • 1 <= nums[i] <= 10^9
  • 1 <= x <= k <= nums.length
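These bounds are why the solution multiplies by 1L and returns long[]: a single value can contribute up to 10^9 * 10^5 = 10^14 to an x-sum, far above Integer.MAX_VALUE (2,147,483,647). A minimal illustration, with hypothetical method names:

```java
// Hypothetical demo of the overflow that the 1L factor in the solution avoids.
class OverflowDemo {
    // Contribution of one value to an x-sum: value * frequency.
    static int contributionInt(int value, int freq) {
        return value * freq;          // 32-bit multiply: silently wraps around on overflow
    }

    static long contributionLong(int value, int freq) {
        return (long) value * freq;   // widen to 64 bits first, then multiply
    }
}
```

In the solution, current * 1L * freq performs the same widening, because current * 1L is already a long before freq is multiplied in.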

Solution

```java
import java.util.*;

class Solution {
    // A distinct value (node) and its number of occurrences (freq) in the current window.
    static class Pair {
        int node, freq;
        public Pair(int node, int freq) {
            this.node = node;
            this.freq = freq;
        }
        @Override
        public String toString() {
            return "(" + node + " " + freq + ")";
        }
        @Override
        public int hashCode() {
            return Objects.hash(node, freq);
        }
        @Override
        public boolean equals(Object obj) {
            if (this == obj) return true;
            if (obj == null || getClass() != obj.getClass()) return false;
            Pair current = (Pair)(obj);
            return current.node == node && current.freq == freq;
        }
    }
    // Ranks pairs from most to least frequent: higher freq first and, on equal freq,
    // the bigger value first. TreeSet also uses this comparator for contains/remove.
    static class custom_sort implements Comparator<Pair> {
        @Override
        public int compare(Pair first, Pair second) {
            int op1 = Integer.compare(second.freq, first.freq);
            if (op1 != 0) return op1;
            return Integer.compare(second.node, first.node);
        }
    }
    public static long[] findXSum(int[] arr, int k, int x) {
        int n = arr.length, p = 0;
        long[] res = new long[n - k + 1];
        // set: the (at most) x highest-ranked pairs; removed: every other pair.
        TreeSet<Pair> set = new TreeSet<>(new custom_sort());
        // map: value -> number of occurrences in the current window.
        HashMap<Integer, Integer> map = new HashMap<>();
        TreeSet<Pair> removed = new TreeSet<>(new custom_sort());
        // sum: x-sum of the current window, i.e. node * freq summed over set.
        long sum = 0;
        // Build the first window arr[0..k-1].
        for (int i = 0; i < k; i++) {
            int current = arr[i];
            if (map.containsKey(current)) {
                // Take the old (current, count) pair out of whichever set holds it.
                if (set.contains(new Pair(current, map.getOrDefault(current, 0)))) sum -= current * 1L * map.getOrDefault(current, 0);
                if (removed.contains(new Pair(current, map.getOrDefault(current, 0)))) removed.remove(new Pair(current, map.getOrDefault(current, 0)));
                set.remove(new Pair(current, map.getOrDefault(current, 0)));

                // Re-insert it into set with the incremented count.
                map.put(current, map.getOrDefault(current, 0) + 1);
                sum += current * 1L * map.getOrDefault(current, 0);
                set.add(new Pair(current, map.getOrDefault(current, 0)));
                // Keep at most x pairs in set by demoting the lowest-ranked one.
                if (set.size() > x) {
                    Pair last = set.pollLast();
                    sum -= last.node * 1L * last.freq;
                    removed.add(new Pair(last.node, map.getOrDefault(last.node, 0)));
                }
            }
            else {
                // First occurrence of current in the window.
                map.put(current, 1);
                set.add(new Pair(current, 1));
                sum += current;
                if (set.size() > x) {
                    Pair last = set.pollLast();
                    sum -= last.node * 1L * last.freq;
                    removed.add(new Pair(last.node, map.getOrDefault(last.node, 0)));
                }
            }
            // Promote pairs from removed while they outrank the lowest-ranked pair in set.
            while (removed.size() > 0) {
                Pair second = removed.first();
                Pair first = set.last();
                if (second.freq > first.freq || (second.freq == first.freq && second.node > first.node)) {
                    sum += second.node * 1L * second.freq;
                    set.add(removed.pollFirst());
                    if (set.size() > x) {
                        Pair r = set.last();
                        removed.add(set.pollLast());
                        sum -= r.node * 1L * r.freq;
                    }
                }
                else break;
            }
        }
        res[p++] = sum;
        int start = 0;
        // Slide the window one step at a time: arr[start] leaves and arr[i] enters.
        for (int i = k; i < n; i++) {
            int prev = arr[start++];
            // Decrement prev's count. A count of 0 is kept as a zero-weight pair, which
            // ranks below every pair with a positive count and adds nothing to sum.
            if (set.contains(new Pair(prev, map.getOrDefault(prev, 0)))) sum -= prev * 1L * map.getOrDefault(prev, 0);
            if (removed.contains(new Pair(prev, map.getOrDefault(prev, 0)))) removed.remove(new Pair(prev, map.getOrDefault(prev, 0)));
            set.remove(new Pair(prev, map.getOrDefault(prev, 0)));

            map.put(prev, map.getOrDefault(prev, 0) - 1);
            sum += prev * 1L * map.getOrDefault(prev, 0);
            set.add(new Pair(prev, map.getOrDefault(prev, 0)));

            if (set.size() > x) {
                Pair last = set.pollLast();
                sum -= last.node * 1L * last.freq;
                removed.add(new Pair(last.node, map.getOrDefault(last.node, 0)));
            }
            // Increment now's count; its old pair (possibly with count 0) may be in either set.
            int now = arr[i];
            if (set.contains(new Pair(now, map.getOrDefault(now, 0)))) sum -= now * 1L * map.getOrDefault(now, 0);
            if (removed.contains(new Pair(now, map.getOrDefault(now, 0)))) removed.remove(new Pair(now, map.getOrDefault(now, 0)));
            set.remove(new Pair(now, map.getOrDefault(now, 0)));

            map.put(now, map.getOrDefault(now, 0) + 1);
            sum += now * 1L * map.getOrDefault(now, 0);
            set.add(new Pair(now, map.getOrDefault(now, 0)));

            if (set.size() > x) {
                Pair last = set.pollLast();
                sum -= last.node * 1L * last.freq;
                removed.add(new Pair(last.node, map.getOrDefault(last.node, 0)));
            }
            // Restore the invariant that every pair in set outranks every pair in removed.
            while (removed.size() > 0) {
                Pair second = removed.first();
                Pair first = set.last();
                if (second.freq > first.freq || (second.freq == first.freq && second.node > first.node)) {
                    sum += second.node * 1L * map.getOrDefault(second.node, 0);
                    set.add(removed.pollFirst());
                    if (set.size() > x) {
                        Pair r = set.last();
                        removed.add(set.pollLast());
                        sum -= r.node * 1L * r.freq;
                    }
                }
                else break;
            }
            res[p++] = sum;
        }
        return res;
    }
}
```
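As a quick check of the ordering that custom_sort imposes, this standalone sketch re-creates the same rule as a lambda (the names OrderDemo, ORDER and window2 are hypothetical) and loads the counts of the second window of Example 1, [1, 2, 2, 3, 4, 2]. Iterating the TreeSet visits the values in the order 2, 4, 3, 1, so the first two pairs give 2 * 3 + 4 * 1 = 10, matching answer[1]. Pair here deliberately has no equals or hashCode: TreeSet decides membership with the comparator alone, so contains still finds (4, 1).

```java
import java.util.*;

// Hypothetical standalone copy of the ranking rule used by custom_sort.
class OrderDemo {
    static final class Pair {
        final int node, freq;
        Pair(int node, int freq) { this.node = node; this.freq = freq; }
    }

    // Higher frequency first; on equal frequency, the bigger value first.
    static final Comparator<Pair> ORDER = (a, b) -> {
        int byFreq = Integer.compare(b.freq, a.freq);
        return byFreq != 0 ? byFreq : Integer.compare(b.node, a.node);
    };

    // Counts for the window [1, 2, 2, 3, 4, 2]: 2 occurs three times, 1, 3 and 4 once each.
    static TreeSet<Pair> window2() {
        TreeSet<Pair> set = new TreeSet<>(ORDER);
        set.add(new Pair(1, 1));
        set.add(new Pair(2, 3));
        set.add(new Pair(3, 1));
        set.add(new Pair(4, 1));
        return set;
    }
}
```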

Complexity Analysis

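  • Time Complexity: O(n log n). Each of the n insertions and n - k removals performs a constant number of HashMap operations and a constant number of O(log n) TreeSet operations; the rebalancing loop only has to repair the counts that just changed.
  • Space Complexity: O(n) for map, set and removed, which hold at most one entry or pair per distinct value of nums, plus the n - k + 1 answers.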
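Summed over the whole run, assuming each TreeSet operation costs O(log n) and each window step needs only a constant number of them:

```latex
T(n) = \underbrace{k \cdot O(\log n)}_{\text{first window}}
     + \underbrace{(n - k) \cdot O(\log n)}_{\text{each slide}}
     = O(n \log n)
```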

Approach

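The solution slides a window of length k across nums and maintains the x-sum incrementally instead of recomputing it for every window.

  • map stores the count of every value in the current window.
  • Pairs (value, count) are ranked by custom_sort: higher count first and, on a tie, the bigger value first. set holds the x highest-ranked pairs and removed holds all the others.
  • sum is the total of value * count over the pairs in set, which is exactly the x-sum of the window.

When a value enters or leaves the window, its old pair is taken out of whichever set contains it (subtracting its contribution from sum if it was in set), its count in map is updated, and the new pair is inserted into set. If set now holds more than x pairs, its lowest-ranked pair is moved to removed. Finally, while the best pair in removed outranks the worst pair in set, the two are exchanged and sum is adjusted, so each slide needs only a few O(log n) TreeSet operations.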
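The same two-set technique can be written more compactly. The sketch below is a hypothetical restatement, not the code from the Solution section: top and rest play the roles of set and removed, the helpers update and rebalance are names introduced here, a value whose count drops to 0 is deleted instead of being kept as a zero-weight pair, and new pairs go into rest first and are promoted by rebalance. It assumes x >= 1, as the constraints guarantee.

```java
import java.util.*;

// Hypothetical condensed sketch of the two-TreeSet technique; not the original Solution.
class TwoSetXSum {
    static final class Pair {
        final int node, freq;
        Pair(int node, int freq) { this.node = node; this.freq = freq; }
    }

    // Higher frequency first; on a tie, the bigger value first (same rule as custom_sort).
    static final Comparator<Pair> ORDER = (a, b) -> {
        int byFreq = Integer.compare(b.freq, a.freq);
        return byFreq != 0 ? byFreq : Integer.compare(b.node, a.node);
    };

    private final TreeSet<Pair> top = new TreeSet<>(ORDER);   // the (at most) x kept pairs
    private final TreeSet<Pair> rest = new TreeSet<>(ORDER);  // every other pair in the window
    private final Map<Integer, Integer> count = new HashMap<>();
    private final int x;
    private long sum;                                         // node * freq summed over top

    TwoSetXSum(int x) { this.x = x; }

    // Change value's count by delta (+1 on entry, -1 on exit) and restore the invariant.
    void update(int value, int delta) {
        int old = count.getOrDefault(value, 0);
        if (old > 0) {
            Pair p = new Pair(value, old);
            if (top.remove(p)) sum -= (long) value * old;
            else rest.remove(p);
        }
        int now = old + delta;
        if (now == 0) {
            count.remove(value);
        } else {
            count.put(value, now);
            rest.add(new Pair(value, now));  // rebalance() promotes it if it belongs in top
        }
        rebalance();
    }

    private void rebalance() {
        // Fill top while it holds fewer than x pairs.
        while (top.size() < x && !rest.isEmpty()) {
            Pair p = rest.pollFirst();
            top.add(p);
            sum += (long) p.node * p.freq;
        }
        // Swap while the best pair of rest outranks the worst pair of top
        // (top is non-empty here because x >= 1 and rest is non-empty).
        while (!rest.isEmpty() && ORDER.compare(rest.first(), top.last()) < 0) {
            Pair in = rest.pollFirst();
            Pair out = top.pollLast();
            top.add(in);
            rest.add(out);
            sum += (long) in.node * in.freq - (long) out.node * out.freq;
        }
    }

    static long[] findXSum(int[] nums, int k, int x) {
        TwoSetXSum window = new TwoSetXSum(x);
        long[] answer = new long[nums.length - k + 1];
        for (int i = 0; i < nums.length; i++) {
            window.update(nums[i], +1);                  // nums[i] enters the window
            if (i >= k) window.update(nums[i - k], -1);  // nums[i - k] leaves the window
            if (i >= k - 1) answer[i - k + 1] = window.sum;
        }
        return answer;
    }
}
```

On Example 1 it returns [6, 10, 12], and on Example 2 it returns [11, 15, 15, 15, 12]. Deleting zero-count values keeps both sets limited to the values actually present in the window, at the cost of a slightly different bookkeeping path from the original.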