array_sample_with_replace.patch

Kenta Murata, 08/03/2010 01:13 PM

Download (3.64 KB)

View differences:

array.c
3775 3775
}
3776 3776

  
3777 3777

  
3778
static VALUE
3779
ary_sample_with_replace(VALUE const ary, long const n)
3780
{
3781
    VALUE result;
3782
    VALUE* ptr_result;
3783
    long i;
3784

  
3785
    VALUE const* const ptr = RARRAY_PTR(ary);
3786
    long const len = RARRAY_LEN(ary);
3787

  
3788
    switch (n) {
3789
      case 0:
3790
        return rb_ary_new2(0);
3791
      case 1:
3792
        return rb_ary_new4(1, &ptr[(long)(rb_genrand_real()*len)]);
3793
      default:
3794
        break;
3795
    }
3796
    result = rb_ary_new2(n);
3797
    ptr_result = RARRAY_PTR(result);
3798
    RB_GC_GUARD(ary);
3799
    for (i = 0; i < n; ++i) {
3800
        long const j = (long)(rb_genrand_real()*len);
3801
        ptr_result[i] = ptr[j];
3802
    }
3803
    ARY_SET_LEN(result, n);
3804
    return result;
3805
}
3806

  
3778 3807
/*
3779 3808
 *  call-seq:
3780 3809
 *     ary.sample        -> obj
......
3792 3821
static VALUE
3793 3822
rb_ary_sample(int argc, VALUE *argv, VALUE ary)
3794 3823
{
3795
    VALUE nv, result, *ptr;
3824
    VALUE nv, replace, result, *ptr;
3796 3825
    long n, len, i, j, k, idx[10];
3797 3826

  
3798 3827
    len = RARRAY_LEN(ary);
......
3801 3830
	i = len == 1 ? 0 : (long)(rb_genrand_real()*len);
3802 3831
	return RARRAY_PTR(ary)[i];
3803 3832
    }
3804
    rb_scan_args(argc, argv, "1", &nv);
3833
    rb_scan_args(argc, argv, "12", &nv, &replace);
3805 3834
    n = NUM2LONG(nv);
3806 3835
    if (n < 0) rb_raise(rb_eArgError, "negative sample number");
3836
    if (RTEST(replace)) {
3837
      return ary_sample_with_replace(ary, n);
3838
    }
3807 3839
    ptr = RARRAY_PTR(ary);
3808 3840
    len = RARRAY_LEN(ary);
3809 3841
    if (n > len) n = len;
test/ruby/test_array.rb
1894 1894
    (0..20).each do |n|
1895 1895
      100.times do
1896 1896
        b = a.sample(n)
1897
        assert_equal([n, 18].min, b.uniq.size)
1897
        assert_equal([n, 18].min, b.size)
1898 1898
        assert_equal(a, (a | b).sort)
1899 1899
        assert_equal(b.sort, (a & b).sort)
1900 1900
      end
......
1909 1909
    assert_raise(ArgumentError, '[ruby-core:23374]') {[1, 2].sample(-1)}
1910 1910
  end
1911 1911

  
1912
  def test_sample_without_replace
1913
    100.times do
1914
      samples = [2, 1, 0].sample(2, false)
1915
      samples.each{|sample|
1916
        assert([0, 1, 2].include?(sample))
1917
      }
1918
    end
1919

  
1920
    srand(0)
1921
    a = (1..18).to_a
1922
    (0..20).each do |n|
1923
      100.times do
1924
        b = a.sample(n, false)
1925
        assert_equal([n, 18].min, b.size)
1926
        assert_equal(a, (a | b).sort)
1927
        assert_equal(b.sort, (a & b).sort)
1928
      end
1929

  
1930
      h = Hash.new(0)
1931
      1000.times do
1932
        a.sample(n, false).each {|x| h[x] += 1 }
1933
      end
1934
      assert_operator(h.values.min * 2, :>=, h.values.max) if n != 0
1935
    end
1936

  
1937
    assert_raise(ArgumentError, '[ruby-core:23374]') {[1, 2].sample(-1, false)}
1938
  end
1939

  
1940
  def test_sample_with_replace
1941
    100.times do
1942
      samples = [2, 1, 0].sample(2, true)
1943
      samples.each{|sample|
1944
        assert([0, 1, 2].include?(sample))
1945
      }
1946
    end
1947

  
1948
    srand(0)
1949
    a = (1..18).to_a
1950
    (0..20).each do |n|
1951
      100.times do
1952
        b = a.sample(n, true)
1953
        assert_equal(n, b.size)
1954
        assert_equal(a, (a | b).sort)
1955
        assert_equal(b.sort.uniq, (a & b).sort)
1956
      end
1957

  
1958
      h = Hash.new(0)
1959
      1000.times do
1960
        a.sample(n, true).each {|x| h[x] += 1 }
1961
      end
1962
      assert_operator(h.values.min * 2, :>=, h.values.max) if n != 0
1963
    end
1964

  
1965
    assert_raise(ArgumentError, '[ruby-core:23374]') {[1, 2].sample(-1, true)}
1966
  end
1967

  
1912 1968
  def test_cycle
1913 1969
    a = []
1914 1970
    [0, 1, 2].cycle do |i|