diff --git a/array.c b/array.c index 485f617..59d99e9 100644 --- a/array.c +++ b/array.c @@ -24,6 +24,8 @@ VALUE rb_cArray; static ID id_cmp; +static VALUE sym_replace; + #define ARY_DEFAULT_SIZE 16 #define ARY_MAX_SIZE (LONG_MAX / (int)sizeof(VALUE)) @@ -3787,18 +3789,18 @@ ary_sample_with_replace(VALUE const ary, long const n) switch (n) { case 0: - return rb_ary_new2(0); + return rb_ary_new2(0); case 1: - return rb_ary_new4(1, &ptr[(long)(rb_genrand_real()*len)]); + return rb_ary_new4(1, &ptr[(long)(rb_genrand_real()*len)]); default: - break; + break; } result = rb_ary_new2(n); ptr_result = RARRAY_PTR(result); RB_GC_GUARD(ary); for (i = 0; i < n; ++i) { - long const j = (long)(rb_genrand_real()*len); - ptr_result[i] = ptr[j]; + long const j = (long)(rb_genrand_real()*len); + ptr_result[i] = ptr[j]; } ARY_SET_LEN(result, n); return result; @@ -3821,7 +3823,7 @@ ary_sample_with_replace(VALUE const ary, long const n) static VALUE rb_ary_sample(int argc, VALUE *argv, VALUE ary) { - VALUE nv, replace, result, *ptr; + VALUE nv, opts, replace=Qfalse, result, *ptr; long n, len, i, j, k, idx[10]; len = RARRAY_LEN(ary); @@ -3830,12 +3832,14 @@ rb_ary_sample(int argc, VALUE *argv, VALUE ary) i = len == 1 ? 0 : (long)(rb_genrand_real()*len); return RARRAY_PTR(ary)[i]; } - rb_scan_args(argc, argv, "12", &nv, &replace); + rb_scan_args(argc, argv, "12", &nv, &opts); n = NUM2LONG(nv); if (n < 0) rb_raise(rb_eArgError, "negative sample number"); - if (RTEST(replace)) { - return ary_sample_with_replace(ary, n); + if (!NIL_P(opts) && TYPE(opts) == T_HASH) { + replace = rb_hash_aref(opts, sym_replace); } + if (RTEST(replace)) + return ary_sample_with_replace(ary, n); ptr = RARRAY_PTR(ary); len = RARRAY_LEN(ary); if (n > len) n = len; @@ -4641,4 +4645,5 @@ Init_Array(void) rb_define_method(rb_cArray, "drop_while", rb_ary_drop_while, 0); id_cmp = rb_intern("<=>"); + sym_replace = ID2SYM(rb_intern("replace")); } diff --git a/test/ruby/test_array.rb b/test/ruby/test_array.rb index 837ce7b..6409cc2 100644 --- a/test/ruby/test_array.rb +++ b/test/ruby/test_array.rb @@ -1911,7 +1911,7 @@ class TestArray < Test::Unit::TestCase def test_sample_without_replace 100.times do - samples = [2, 1, 0].sample(2, false) + samples = [2, 1, 0].sample(2, replace: false) samples.each{|sample| assert([0, 1, 2].include?(sample)) } @@ -1921,7 +1921,7 @@ class TestArray < Test::Unit::TestCase a = (1..18).to_a (0..20).each do |n| 100.times do - b = a.sample(n, false) + b = a.sample(n, replace: false) assert_equal([n, 18].min, b.size) assert_equal(a, (a | b).sort) assert_equal(b.sort, (a & b).sort) @@ -1929,17 +1929,17 @@ class TestArray < Test::Unit::TestCase h = Hash.new(0) 1000.times do - a.sample(n, false).each {|x| h[x] += 1 } + a.sample(n, replace: false).each {|x| h[x] += 1 } end assert_operator(h.values.min * 2, :>=, h.values.max) if n != 0 end - assert_raise(ArgumentError, '[ruby-core:23374]') {[1, 2].sample(-1, false)} + assert_raise(ArgumentError, '[ruby-core:23374]') {[1, 2].sample(-1, replace: false)} end def test_sample_with_replace 100.times do - samples = [2, 1, 0].sample(2, true) + samples = [2, 1, 0].sample(2, replace: true) samples.each{|sample| assert([0, 1, 2].include?(sample)) } @@ -1949,7 +1949,7 @@ class TestArray < Test::Unit::TestCase a = (1..18).to_a (0..20).each do |n| 100.times do - b = a.sample(n, true) + b = a.sample(n, replace: true) assert_equal(n, b.size) assert_equal(a, (a | b).sort) assert_equal(b.sort.uniq, (a & b).sort) @@ -1957,12 +1957,12 @@ class TestArray < Test::Unit::TestCase h = Hash.new(0) 1000.times do - a.sample(n, true).each {|x| h[x] += 1 } + a.sample(n, replace: true).each {|x| h[x] += 1 } end assert_operator(h.values.min * 2, :>=, h.values.max) if n != 0 end - assert_raise(ArgumentError, '[ruby-core:23374]') {[1, 2].sample(-1, true)} + assert_raise(ArgumentError, '[ruby-core:23374]') {[1, 2].sample(-1, replace: true)} end def test_cycle