array_sample_with_replace_hash.patch

Kenta Murata, 08/03/2010 04:53 PM

Download (4.35 KB)

View differences:

array.c
24 24

  
25 25
static ID id_cmp;
26 26

  
27
static VALUE sym_replace;
28

  
27 29
#define ARY_DEFAULT_SIZE 16
28 30
#define ARY_MAX_SIZE (LONG_MAX / (int)sizeof(VALUE))
29 31

  
......
3787 3789

  
3788 3790
    switch (n) {
3789 3791
      case 0:
3790
        return rb_ary_new2(0);
3792
	return rb_ary_new2(0);
3791 3793
      case 1:
3792
        return rb_ary_new4(1, &ptr[(long)(rb_genrand_real()*len)]);
3794
	return rb_ary_new4(1, &ptr[(long)(rb_genrand_real()*len)]);
3793 3795
      default:
3794
        break;
3796
	break;
3795 3797
    }
3796 3798
    result = rb_ary_new2(n);
3797 3799
    ptr_result = RARRAY_PTR(result);
3798 3800
    RB_GC_GUARD(ary);
3799 3801
    for (i = 0; i < n; ++i) {
3800
        long const j = (long)(rb_genrand_real()*len);
3801
        ptr_result[i] = ptr[j];
3802
	long const j = (long)(rb_genrand_real()*len);
3803
	ptr_result[i] = ptr[j];
3802 3804
    }
3803 3805
    ARY_SET_LEN(result, n);
3804 3806
    return result;
......
3821 3823
static VALUE
3822 3824
rb_ary_sample(int argc, VALUE *argv, VALUE ary)
3823 3825
{
3824
    VALUE nv, replace, result, *ptr;
3826
    VALUE nv, opts, replace=Qfalse, result, *ptr;
3825 3827
    long n, len, i, j, k, idx[10];
3826 3828

  
3827 3829
    len = RARRAY_LEN(ary);
......
3830 3832
	i = len == 1 ? 0 : (long)(rb_genrand_real()*len);
3831 3833
	return RARRAY_PTR(ary)[i];
3832 3834
    }
3833
    rb_scan_args(argc, argv, "12", &nv, &replace);
3835
    rb_scan_args(argc, argv, "12", &nv, &opts);
3834 3836
    n = NUM2LONG(nv);
3835 3837
    if (n < 0) rb_raise(rb_eArgError, "negative sample number");
3836
    if (RTEST(replace)) {
3837
      return ary_sample_with_replace(ary, n);
3838
    if (!NIL_P(opts) && TYPE(opts) == T_HASH) {
3839
	replace = rb_hash_aref(opts, sym_replace);
3838 3840
    }
3841
    if (RTEST(replace))
3842
	return ary_sample_with_replace(ary, n);
3839 3843
    ptr = RARRAY_PTR(ary);
3840 3844
    len = RARRAY_LEN(ary);
3841 3845
    if (n > len) n = len;
......
4641 4645
    rb_define_method(rb_cArray, "drop_while", rb_ary_drop_while, 0);
4642 4646

  
4643 4647
    id_cmp = rb_intern("<=>");
4648
    sym_replace = ID2SYM(rb_intern("replace"));
4644 4649
}
test/ruby/test_array.rb
1911 1911

  
1912 1912
  def test_sample_without_replace
1913 1913
    100.times do
1914
      samples = [2, 1, 0].sample(2, false)
1914
      samples = [2, 1, 0].sample(2, replace: false)
1915 1915
      samples.each{|sample|
1916 1916
        assert([0, 1, 2].include?(sample))
1917 1917
      }
......
1921 1921
    a = (1..18).to_a
1922 1922
    (0..20).each do |n|
1923 1923
      100.times do
1924
        b = a.sample(n, false)
1924
        b = a.sample(n, replace: false)
1925 1925
        assert_equal([n, 18].min, b.size)
1926 1926
        assert_equal(a, (a | b).sort)
1927 1927
        assert_equal(b.sort, (a & b).sort)
......
1929 1929

  
1930 1930
      h = Hash.new(0)
1931 1931
      1000.times do
1932
        a.sample(n, false).each {|x| h[x] += 1 }
1932
        a.sample(n, replace: false).each {|x| h[x] += 1 }
1933 1933
      end
1934 1934
      assert_operator(h.values.min * 2, :>=, h.values.max) if n != 0
1935 1935
    end
1936 1936

  
1937
    assert_raise(ArgumentError, '[ruby-core:23374]') {[1, 2].sample(-1, false)}
1937
    assert_raise(ArgumentError, '[ruby-core:23374]') {[1, 2].sample(-1, replace: false)}
1938 1938
  end
1939 1939

  
1940 1940
  def test_sample_with_replace
1941 1941
    100.times do
1942
      samples = [2, 1, 0].sample(2, true)
1942
      samples = [2, 1, 0].sample(2, replace: true)
1943 1943
      samples.each{|sample|
1944 1944
        assert([0, 1, 2].include?(sample))
1945 1945
      }
......
1949 1949
    a = (1..18).to_a
1950 1950
    (0..20).each do |n|
1951 1951
      100.times do
1952
        b = a.sample(n, true)
1952
        b = a.sample(n, replace: true)
1953 1953
        assert_equal(n, b.size)
1954 1954
        assert_equal(a, (a | b).sort)
1955 1955
        assert_equal(b.sort.uniq, (a & b).sort)
......
1957 1957

  
1958 1958
      h = Hash.new(0)
1959 1959
      1000.times do
1960
        a.sample(n, true).each {|x| h[x] += 1 }
1960
        a.sample(n, replace: true).each {|x| h[x] += 1 }
1961 1961
      end
1962 1962
      assert_operator(h.values.min * 2, :>=, h.values.max) if n != 0
1963 1963
    end
1964 1964

  
1965
    assert_raise(ArgumentError, '[ruby-core:23374]') {[1, 2].sample(-1, true)}
1965
    assert_raise(ArgumentError, '[ruby-core:23374]') {[1, 2].sample(-1, replace: true)}
1966 1966
  end
1967 1967

  
1968 1968
  def test_cycle