array_sample_with_replace_hash.patch
array.c  

24  24  
25  25 
static ID id_cmp; 
26  26  
27 
static VALUE sym_replace; 

28  
27  29 
#define ARY_DEFAULT_SIZE 16 
28  30 
#define ARY_MAX_SIZE (LONG_MAX / (int)sizeof(VALUE)) 
29  31  
...  ...  
3787  3789  
3788  3790 
switch (n) { 
3789  3791 
case 0: 
3790 
return rb_ary_new2(0);


3792 
return rb_ary_new2(0);


3791  3793 
case 1: 
3792 
return rb_ary_new4(1, &ptr[(long)(rb_genrand_real()*len)]);


3794 
return rb_ary_new4(1, &ptr[(long)(rb_genrand_real()*len)]);


3793  3795 
default: 
3794 
break;


3796 
break;


3795  3797 
} 
3796  3798 
result = rb_ary_new2(n); 
3797  3799 
ptr_result = RARRAY_PTR(result); 
3798  3800 
RB_GC_GUARD(ary); 
3799  3801 
for (i = 0; i < n; ++i) { 
3800 
long const j = (long)(rb_genrand_real()*len);


3801 
ptr_result[i] = ptr[j];


3802 
long const j = (long)(rb_genrand_real()*len);


3803 
ptr_result[i] = ptr[j];


3802  3804 
} 
3803  3805 
ARY_SET_LEN(result, n); 
3804  3806 
return result; 
...  ...  
3821  3823 
static VALUE 
3822  3824 
rb_ary_sample(int argc, VALUE *argv, VALUE ary) 
3823  3825 
{ 
3824 
VALUE nv, replace, result, *ptr;


3826 
VALUE nv, opts, replace=Qfalse, result, *ptr;


3825  3827 
long n, len, i, j, k, idx[10]; 
3826  3828  
3827  3829 
len = RARRAY_LEN(ary); 
...  ...  
3830  3832 
i = len == 1 ? 0 : (long)(rb_genrand_real()*len); 
3831  3833 
return RARRAY_PTR(ary)[i]; 
3832  3834 
} 
3833 
rb_scan_args(argc, argv, "12", &nv, &replace);


3835 
rb_scan_args(argc, argv, "12", &nv, &opts);


3834  3836 
n = NUM2LONG(nv); 
3835  3837 
if (n < 0) rb_raise(rb_eArgError, "negative sample number"); 
3836 
if (RTEST(replace)) {


3837 
return ary_sample_with_replace(ary, n);


3838 
if (!NIL_P(opts) && TYPE(opts) == T_HASH) {


3839 
replace = rb_hash_aref(opts, sym_replace);


3838  3840 
} 
3841 
if (RTEST(replace)) 

3842 
return ary_sample_with_replace(ary, n); 

3839  3843 
ptr = RARRAY_PTR(ary); 
3840  3844 
len = RARRAY_LEN(ary); 
3841  3845 
if (n > len) n = len; 
...  ...  
4641  4645 
rb_define_method(rb_cArray, "drop_while", rb_ary_drop_while, 0); 
4642  4646  
4643  4647 
id_cmp = rb_intern("<=>"); 
4648 
sym_replace = ID2SYM(rb_intern("replace")); 

4644  4649 
} 
test/ruby/test_array.rb  

1911  1911  
1912  1912 
def test_sample_without_replace 
1913  1913 
100.times do 
1914 
samples = [2, 1, 0].sample(2, false) 

1914 
samples = [2, 1, 0].sample(2, replace: false)


1915  1915 
samples.each{sample 
1916  1916 
assert([0, 1, 2].include?(sample)) 
1917  1917 
} 
...  ...  
1921  1921 
a = (1..18).to_a 
1922  1922 
(0..20).each do n 
1923  1923 
100.times do 
1924 
b = a.sample(n, false) 

1924 
b = a.sample(n, replace: false)


1925  1925 
assert_equal([n, 18].min, b.size) 
1926  1926 
assert_equal(a, (a  b).sort) 
1927  1927 
assert_equal(b.sort, (a & b).sort) 
...  ...  
1929  1929  
1930  1930 
h = Hash.new(0) 
1931  1931 
1000.times do 
1932 
a.sample(n, false).each {x h[x] += 1 } 

1932 
a.sample(n, replace: false).each {x h[x] += 1 }


1933  1933 
end 
1934  1934 
assert_operator(h.values.min * 2, :>=, h.values.max) if n != 0 
1935  1935 
end 
1936  1936  
1937 
assert_raise(ArgumentError, '[rubycore:23374]') {[1, 2].sample(1, false)} 

1937 
assert_raise(ArgumentError, '[rubycore:23374]') {[1, 2].sample(1, replace: false)}


1938  1938 
end 
1939  1939  
1940  1940 
def test_sample_with_replace 
1941  1941 
100.times do 
1942 
samples = [2, 1, 0].sample(2, true) 

1942 
samples = [2, 1, 0].sample(2, replace: true)


1943  1943 
samples.each{sample 
1944  1944 
assert([0, 1, 2].include?(sample)) 
1945  1945 
} 
...  ...  
1949  1949 
a = (1..18).to_a 
1950  1950 
(0..20).each do n 
1951  1951 
100.times do 
1952 
b = a.sample(n, true) 

1952 
b = a.sample(n, replace: true)


1953  1953 
assert_equal(n, b.size) 
1954  1954 
assert_equal(a, (a  b).sort) 
1955  1955 
assert_equal(b.sort.uniq, (a & b).sort) 
...  ...  
1957  1957  
1958  1958 
h = Hash.new(0) 
1959  1959 
1000.times do 
1960 
a.sample(n, true).each {x h[x] += 1 } 

1960 
a.sample(n, replace: true).each {x h[x] += 1 }


1961  1961 
end 
1962  1962 
assert_operator(h.values.min * 2, :>=, h.values.max) if n != 0 
1963  1963 
end 
1964  1964  
1965 
assert_raise(ArgumentError, '[rubycore:23374]') {[1, 2].sample(1, true)} 

1965 
assert_raise(ArgumentError, '[rubycore:23374]') {[1, 2].sample(1, replace: true)}


1966  1966 
end 
1967  1967  
1968  1968 
def test_cycle 