Discrete.java


Below is the syntax highlighted version of Discrete.java from §9.8 Data Analysis.


/******************************************************************************
 *  Compilation:  javac Discrete.java
 *  Execution:    java Discrete N
 *
 *  Implementation of algorithm for sampling from a discrete
 *  probability N-vector X[1], X[2], ..., X[N].
 *  Runs nn O(1) worst case time per sample, using one integer 
 *  array and one double array of size N for storage.
 *  Preprocessing consumes O(N) time and temporarily uses one 
 *  additional integer array of size N for bookkeeping. 
 *
 *  This code is a Java version of Warren D. Smith's WDSsampler.c
 *
 *  This method was originally developed by Walker and improved
 *  by Kronmal and Peterson (1979).
 *
 ******************************************************************************/

public class Discrete {
    private double[] Y;
    private int[] A;
    private int N;

    // Walker's sampling algorithm
    // choose i between 1 and N uniformly at random
    // return i with prob Y[i]; otherwise return i
    public int random() {
        // i = random uniform integer from { 1, 2, ..., N }
        int i = 1 + (int) (N * Math.random());
        double r = Math.random();
        if (r > Y[i]) i = A[i];
        return i;
    }

    public Discrete(double[] X) {
        N = X.length - 2;
        Y = new double[N+2];
        for (int i = 1; i <= N; i++) Y[i] = X[i];
        A = new int[N+2];
        int[] B = new int[N+2];   // for bookkeeping
        for(int i = 1; i <= N; i++) { 
            A[i] = B[i] = i;            // initial destins = stay there
            Y[i] = Y[i] * N;            // scale probability vector
        }

        // sentinels
        B[0]   = 0;
        B[N+1] = N+1;
        Y[0]   = 0.0;
        Y[N+1] = 2.0;

        int i = 0; 
        int j = N + 1;
        while(true) {
            do{ i++; } while(Y[B[i]] <  1.0);  // find i so X[B[i]] needs more
            do{ j--; } while(Y[B[j]] >= 1.0);  // find j so X[B[j]] wants less
            if(i >= j) break;
            int k = B[i]; B[i] = B[j]; B[j] = k;   // swap B[i] and B[j]
        }
        i = j;
        j++;
        while(i > 0) {
            while(Y[B[j]] <= 1.0) { j++; }  // find j so X[B[j]] needs more
            if(j > N) break;
            Y[B[j]] -= 1.0 - Y[B[i]];       // B[i] will donate to B[j] to fix up
            A[B[i]] = B[j];             
            if(Y[B[j]] < 1.0) {             // X[B[j]] now wants less so readjust ordering
                int k = B[i]; B[i] = B[j]; B[j] = k;   // swap B[j] and B[i]
                j++;
            }
            else { i--; }
        }
    }


    public static void main(String[] args) {
        int N = Integer.parseInt(args[0]);
        double[] X = { 0, .1, .4, .12, .38, 0 };
        int[] hist = new int[N+2];
        Discrete d = new Discrete(X);
        for (int i = 0; i < N; i++)
            hist[d.random()]++;

        for (int i = 1; i <= X.length - 2; i++)
            StdOut.println(i + ":  " + X[i] + " " + (1.0 * hist[i] / N));

    }

   
}


Copyright © 2000–2017, Robert Sedgewick and Kevin Wayne.
Last updated: Fri Oct 20 14:12:12 EDT 2017.