Feature #3647 » array_sample_with_replace_hash.patch
array.c  

static ID id_cmp;


static VALUE sym_replace;


#define ARY_DEFAULT_SIZE 16


#define ARY_MAX_SIZE (LONG_MAX / (int)sizeof(VALUE))


...  ...  
switch (n) {


case 0:


return rb_ary_new2(0);


return rb_ary_new2(0);


case 1:


return rb_ary_new4(1, &ptr[(long)(rb_genrand_real()*len)]);


return rb_ary_new4(1, &ptr[(long)(rb_genrand_real()*len)]);


default:


break;


break;


}


result = rb_ary_new2(n);


ptr_result = RARRAY_PTR(result);


RB_GC_GUARD(ary);


for (i = 0; i < n; ++i) {


long const j = (long)(rb_genrand_real()*len);


ptr_result[i] = ptr[j];


long const j = (long)(rb_genrand_real()*len);


ptr_result[i] = ptr[j];


}


ARY_SET_LEN(result, n);


return result;


...  ...  
static VALUE


rb_ary_sample(int argc, VALUE *argv, VALUE ary)


{


VALUE nv, replace, result, *ptr;


VALUE nv, opts, replace=Qfalse, result, *ptr;


long n, len, i, j, k, idx[10];


len = RARRAY_LEN(ary);


...  ...  
i = len == 1 ? 0 : (long)(rb_genrand_real()*len);


return RARRAY_PTR(ary)[i];


}


rb_scan_args(argc, argv, "12", &nv, &replace);


rb_scan_args(argc, argv, "12", &nv, &opts);


n = NUM2LONG(nv);


if (n < 0) rb_raise(rb_eArgError, "negative sample number");


if (RTEST(replace)) {


return ary_sample_with_replace(ary, n);


if (!NIL_P(opts) && TYPE(opts) == T_HASH) {


replace = rb_hash_aref(opts, sym_replace);


}


if (RTEST(replace))


return ary_sample_with_replace(ary, n);


ptr = RARRAY_PTR(ary);


len = RARRAY_LEN(ary);


if (n > len) n = len;


...  ...  
rb_define_method(rb_cArray, "drop_while", rb_ary_drop_while, 0);


id_cmp = rb_intern("<=>");


sym_replace = ID2SYM(rb_intern("replace"));


}

test/ruby/test_array.rb  

def test_sample_without_replace


100.times do


samples = [2, 1, 0].sample(2, false)


samples = [2, 1, 0].sample(2, replace: false)


samples.each{sample


assert([0, 1, 2].include?(sample))


}


...  ...  
a = (1..18).to_a


(0..20).each do n


100.times do


b = a.sample(n, false)


b = a.sample(n, replace: false)


assert_equal([n, 18].min, b.size)


assert_equal(a, (a  b).sort)


assert_equal(b.sort, (a & b).sort)


...  ...  
h = Hash.new(0)


1000.times do


a.sample(n, false).each {x h[x] += 1 }


a.sample(n, replace: false).each {x h[x] += 1 }


end


assert_operator(h.values.min * 2, :>=, h.values.max) if n != 0


end


assert_raise(ArgumentError, '[rubycore:23374]') {[1, 2].sample(1, false)}


assert_raise(ArgumentError, '[rubycore:23374]') {[1, 2].sample(1, replace: false)}


end


def test_sample_with_replace


100.times do


samples = [2, 1, 0].sample(2, true)


samples = [2, 1, 0].sample(2, replace: true)


samples.each{sample


assert([0, 1, 2].include?(sample))


}


...  ...  
a = (1..18).to_a


(0..20).each do n


100.times do


b = a.sample(n, true)


b = a.sample(n, replace: true)


assert_equal(n, b.size)


assert_equal(a, (a  b).sort)


assert_equal(b.sort.uniq, (a & b).sort)


...  ...  
h = Hash.new(0)


1000.times do


a.sample(n, true).each {x h[x] += 1 }


a.sample(n, replace: true).each {x h[x] += 1 }


end


assert_operator(h.values.min * 2, :>=, h.values.max) if n != 0


end


assert_raise(ArgumentError, '[rubycore:23374]') {[1, 2].sample(1, true)}


assert_raise(ArgumentError, '[rubycore:23374]') {[1, 2].sample(1, replace: true)}


end


def test_cycle
