Given an array w of positive integers, where w[i] describes the weight of index i, write a function pickIndex which randomly picks an index in proportion to its weight.

Note:

  1. 1 <= w.length <= 10000
  2. 1 <= w[i] <= 10^5
  3. pickIndex will be called at most 10000 times.
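These bounds also mean that every running total used in the solutions below fits in a Java `int`. The largest possible total weight is:

$$\text{totalWeight} \le 10^4 \cdot 10^5 = 10^9 < 2^{31} - 1 = 2147483647$$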

Example 1:

Input: 
["Solution","pickIndex"]
[[[1]],[]]
Output: [null,0]

Example 2:

Input: 
["Solution","pickIndex","pickIndex","pickIndex","pickIndex","pickIndex"]
[[[1,3]],[],[],[],[],[]]
Output: [null,0,1,1,1,0]

Explanation of Input Syntax:

The input is two lists: the subroutines called and their arguments. Solution's constructor has one argument, the array w; pickIndex has no arguments. Arguments are always wrapped in a list, even if there aren't any.

[code lang="java"]
class Solution {

    public Solution(int[] w) {
        
    }
    
    public int pickIndex() {
        
    }
}

/**
 * Your Solution object will be instantiated and called as such:
 * Solution obj = new Solution(w);
 * int param_1 = obj.pickIndex();
 */
[code]

Idea – 1

Consider the weights $\langle 1, 3, 2, 4 \rangle$: indices 0, 1, 2, and 3 have weights 1, 3, 2, and 4. For this example totalWeight = 1+3+2+4 = 10, so we must pick index 0 10% of the time, index 1 30% of the time, index 2 20% of the time, and index 3 40% of the time. We split the range [1, 10] into consecutive intervals whose lengths equal the weights: [1, 1], [2, 4], [5, 6], and [7, 10]. Now if we draw x uniformly at random from [1, 10], we pick 0 if x is in [1, 1], 1 if x is in [2, 4], and so on. Index i is therefore picked with probability w[i]/totalWeight, i.e. in proportion to its weight. The constructor takes O(n) time and each pickIndex call scans up to n intervals, so a call costs O(n) time and k calls cost O(kn), which is O(n^2) when k is on the order of n. The extra space is O(1), since we only keep a reference to w.
[code lang="java"]
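import java.util.Random;

class Solution {
    
    private static final Random rng = new Random(System.currentTimeMillis()%Integer.MAX_VALUE);
    private int totalWeight = 0;
    private int[] weights;
    
    public Solution(int[] w) {
        // Sum all weights once; draws will come from [1, totalWeight].
        for(int x : w)
        {
            totalWeight += x;
        }

        weights = w;
    }
    
    public int pickIndex() {
        // Draw x uniformly from [1, totalWeight].
        int x = 1+rng.nextInt(totalWeight);
        
        // Walk the intervals left to right; index i owns (cumsum - weights[i], cumsum].
        int cumsum = 0;
        for(int i = 0; i < weights.length; ++i)
        {
            cumsum += weights[i];
            if(x <= cumsum)
            {
                return i;
            }
        }
        
        return -1; // unreachable, since x <= totalWeight
    }
}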
[code]

Runtime: 90 ms, faster than 13.88% of Java online submissions for Random Pick with Weight.
Memory Usage: 52 MB, less than 35.80% of Java online submissions for Random Pick with Weight.
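To see the interval argument at work without any randomness, here is a small sketch that feeds every possible draw x from 1 to totalWeight through the same linear scan as Idea 1 and counts how often each index comes back. The class `IntervalCheck` and its methods are hypothetical helpers, not part of the solution above. For the weights {1, 3, 2, 4} the counts should equal the weights themselves.

```java
import java.util.Arrays;

// Hypothetical helper (not part of the original solution): replays Idea 1's
// linear scan for every possible draw x in [1, totalWeight] and counts the picks.
class IntervalCheck {

    // Same mapping as Idea 1's pickIndex, but x is supplied by the caller
    // instead of being drawn from a Random.
    static int indexFor(int[] w, int x) {
        int cumsum = 0;
        for (int i = 0; i < w.length; ++i) {
            cumsum += w[i];
            if (x <= cumsum) {
                return i;
            }
        }
        return -1; // unreachable when 1 <= x <= sum(w)
    }

    // Counts how many draws x in [1, total] land on each index.
    static int[] countPicks(int[] w) {
        int total = Arrays.stream(w).sum();
        int[] counts = new int[w.length];
        for (int x = 1; x <= total; ++x) {
            counts[indexFor(w, x)]++;
        }
        return counts;
    }

    public static void main(String[] args) {
        int[] w = {1, 3, 2, 4};
        // Each index should be hit exactly w[i] times out of the 10 possible draws.
        System.out.println(Arrays.toString(countPicks(w))); // [1, 3, 2, 4]
    }
}
```

Since every x in [1, totalWeight] is equally likely, these counts divided by totalWeight are exactly the pick probabilities.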

Idea – 2

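If we precompute the cumulative sums (cumsum) in the constructor, pickIndex no longer has to add up the weights on every call; it only compares x against the stored totals. The asymptotic cost is unchanged (still O(n) per pickIndex call), but the extra space becomes O(n) for the cumsum array.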
[code lang="java"]
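import java.util.Random;

class Solution {
    
    private static final Random rng = new Random(System.currentTimeMillis()%Integer.MAX_VALUE);
    // cumsum[i] = w[0] + ... + w[i], the right endpoint of index i's interval.
    private int[] cumsum;
    
    public Solution(int[] w) {
        // Build the prefix sums once; w.length >= 1 by the constraints.
        cumsum = new int[w.length];
        cumsum[0] = w[0];
        for(int i = 1; i < w.length; ++i)
        {
            cumsum[i] = cumsum[i-1] + w[i];
        }
    }
    
    public int pickIndex() {
        // Draw x uniformly from [1, totalWeight]; the total is the last prefix sum.
        int x = 1+rng.nextInt(cumsum[cumsum.length-1]);
        
        // Linear scan for the first interval whose right endpoint reaches x.
        for(int i = 0; i < cumsum.length; ++i)
        {
            if(x <= cumsum[i])
            {
                return i;
            }
        }
        
        return -1; // unreachable, since x <= cumsum[cumsum.length-1]
    }
}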
[code]

Runtime: 80 ms, faster than 27.95% of Java online submissions for Random Pick with Weight.
Memory Usage: 52.5 MB, less than 34.07% of Java online submissions for Random Pick with Weight.
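The cumsum array is exactly the list of right endpoints of the intervals from Idea 1: for weights {1, 3, 2, 4} it is {1, 4, 6, 10}. As an alternative to the explicit loop in the constructor, the JDK's `Arrays.parallelPrefix` (Java 8 or later) can build the same array. The sketch below assumes that; `PrefixSums` is a hypothetical helper name, not part of the solution above.

```java
import java.util.Arrays;

// Hypothetical helper (not from the original post): builds the prefix-sum
// array of Idea 2 with Arrays.parallelPrefix instead of an explicit loop.
class PrefixSums {

    static int[] cumsum(int[] w) {
        int[] c = Arrays.copyOf(w, w.length);   // leave the caller's array untouched
        Arrays.parallelPrefix(c, Integer::sum); // now c[i] = w[0] + ... + w[i]
        return c;
    }

    public static void main(String[] args) {
        int[] c = cumsum(new int[]{1, 3, 2, 4});
        // Right endpoints of the intervals [1, 1], [2, 4], [5, 6], [7, 10].
        System.out.println(Arrays.toString(c)); // [1, 4, 6, 10]
    }
}
```

The `parallelPrefix` Javadoc notes that parallel prefix computation is usually more efficient than sequential loops for large arrays; with at most 10000 weights, the plain loop in the solution above is a perfectly reasonable choice.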

Idea – 3

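Since all weights are positive, cumsum is strictly increasing, so pickIndex can use binary search to find the smallest index i with x <= cumsum[i]. This reduces each pickIndex call to O(lg n) time, so k calls cost O(n + k lg n) including the O(n) constructor, which is O(n lg n) when k is on the order of n. Space remains O(n).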
[code lang="java"]
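import java.util.Random;

class Solution {
    
    private static final Random rng = new Random(System.currentTimeMillis()%Integer.MAX_VALUE);
    // cumsum[i] = w[0] + ... + w[i]; strictly increasing because every w[i] >= 1.
    private int[] cumsum;
    
    public Solution(int[] w) {
        
        cumsum = new int[w.length];
        cumsum[0] = w[0];
        for(int i = 1; i < w.length; ++i)
        {
            cumsum[i] = cumsum[i-1] + w[i];
        }
    }
    
    public int pickIndex() {
        // Draw x uniformly from [1, totalWeight].
        int x = 1+rng.nextInt(cumsum[cumsum.length-1]);
        
        return ceilSearch(cumsum, x);
    }
    
    // Returns the smallest index i with A[i] >= key (the "ceiling" of key).
    // Assumes A is strictly increasing and key <= A[A.length-1].
    private int ceilSearch(int[] A, int key)
    {
        int lo = 0, hi = A.length-1;
        while(lo <= hi)
        {
            int mid = lo+(hi-lo)/2; // avoids int overflow of (lo+hi)
            if(key == A[mid])
            {
                return mid; // an exact hit is unique since A is strictly increasing
            }
            else if(key < A[mid])
            {
                hi = mid-1;
            }
            else
            {
                lo = mid+1;
            }
        }
        
        // On exit every A[i] with i < lo is < key and every A[i] with i >= lo is > key.
        return lo; 
    }
}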
[code]

Runtime: 67 ms, faster than 92.00% of Java online submissions for Random Pick with Weight.
Memory Usage: 44.3 MB, less than 95.91% of Java online submissions for Random Pick with Weight.
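If you would rather not hand-roll ceilSearch, `java.util.Arrays.binarySearch` can do the same job: on a miss it returns -(insertionPoint) - 1, and the insertion point is the index of the first element greater than the key, which is exactly the ceiling we want. This relies on cumsum being strictly increasing (all weights are positive), so an exact hit can only land on one index. A sketch under those assumptions; `WeightedPicker` and `indexFor` are hypothetical names, and the `Random` is passed in so the picker can be seeded for testing.

```java
import java.util.Arrays;
import java.util.Random;

// Sketch of Idea 3 using the JDK's Arrays.binarySearch in place of ceilSearch.
// Class and method names are hypothetical, not from the original post.
class WeightedPicker {

    private final Random rng;
    private final int[] cumsum;

    WeightedPicker(int[] w, Random rng) {
        this.rng = rng;
        cumsum = new int[w.length];
        cumsum[0] = w[0];
        for (int i = 1; i < w.length; ++i) {
            cumsum[i] = cumsum[i - 1] + w[i];
        }
    }

    // Smallest i with cumsum[i] >= x. cumsum is strictly increasing, so an
    // exact hit is unique; on a miss binarySearch returns -(insertionPoint) - 1,
    // and -r - 1 recovers the insertion point.
    int indexFor(int x) {
        int r = Arrays.binarySearch(cumsum, x);
        return r >= 0 ? r : -r - 1;
    }

    int pickIndex() {
        // Draw x uniformly from [1, totalWeight].
        int x = 1 + rng.nextInt(cumsum[cumsum.length - 1]);
        return indexFor(x);
    }
}
```

For example, `new WeightedPicker(new int[]{1, 3, 2, 4}, new Random(42))` always produces the same sequence of picks, which makes the frequencies easy to check in a unit test.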
