diff --git a/array.c b/array.c index c2a4e85..0c7c093 100644 --- a/array.c +++ b/array.c @@ -3732,22 +3732,33 @@ rb_ary_flatten(int argc, VALUE *argv, VALUE ary) /* * call-seq: - * ary.shuffle! -> ary + * ary.shuffle! -> ary + * ary.shuffle!(random: RANDOM) -> ary * * Shuffles elements in +self+ in place. + * Unless an Random object is passed with :random keyword, + * use default random number generator. */ static VALUE -rb_ary_shuffle_bang(VALUE ary) +rb_ary_shuffle_bang(int argc, VALUE* argv, VALUE ary) { VALUE *ptr; + VALUE opt, random = Qnil; long i = RARRAY_LEN(ary); + if (argc > 0 && !NIL_P(opt = rb_check_hash_type(argv[argc-1]))) { + --argc; + random = rb_hash_aref(opt, ID2SYM(rb_intern("random"))); + } + if (NIL_P(random)) { + random = rb_const_get_at(rb_cRandom, rb_intern("DEFAULT")); + } rb_ary_modify(ary); ptr = RARRAY_PTR(ary); while (i) { - long j = (long)(rb_genrand_real()*i); + long j = (long)(rb_random_real(random)*(i)); VALUE tmp = ptr[--i]; ptr[i] = ptr[j]; ptr[j] = tmp; @@ -3767,10 +3778,10 @@ rb_ary_shuffle_bang(VALUE ary) */ static VALUE -rb_ary_shuffle(VALUE ary) +rb_ary_shuffle(int argc, VALUE* argv, VALUE ary) { ary = rb_ary_dup(ary); - rb_ary_shuffle_bang(ary); + rb_ary_shuffle_bang(argc, argv, ary); return ary; } @@ -3793,12 +3804,22 @@ static VALUE rb_ary_sample(int argc, VALUE *argv, VALUE ary) { VALUE nv, result, *ptr; + VALUE opt, replace = Qnil, random = Qnil; long n, len, i, j, k, idx[10]; +#define RAND_UPTO(n) (long)(rb_random_real(random)*(n)) len = RARRAY_LEN(ary); + if (argc > 0 && !NIL_P(opt = rb_check_hash_type(argv[argc-1]))) { + --argc; + replace = rb_hash_aref(opt, ID2SYM(rb_intern("replace"))); + random = rb_hash_aref(opt, ID2SYM(rb_intern("random"))); + } + if (NIL_P(random)) { + random = rb_const_get_at(rb_cRandom, rb_intern("DEFAULT")); + } if (argc == 0) { if (len == 0) return Qnil; - i = len == 1 ? 0 : (long)(rb_genrand_real()*len); + i = len == 1 ? 0 : RAND_UPTO(len); return RARRAY_PTR(ary)[i]; } rb_scan_args(argc, argv, "1", &nv); @@ -3806,20 +3827,28 @@ rb_ary_sample(int argc, VALUE *argv, VALUE ary) if (n < 0) rb_raise(rb_eArgError, "negative sample number"); ptr = RARRAY_PTR(ary); len = RARRAY_LEN(ary); + RB_GC_GUARD(ary); + if (RTEST(replace)) { + result = rb_ary_new2(n); + while (n-- > 0) { + rb_ary_push(result, ptr[RAND_UPTO(len)]); + } + return result; + } if (n > len) n = len; switch (n) { case 0: return rb_ary_new2(0); case 1: - return rb_ary_new4(1, &ptr[(long)(rb_genrand_real()*len)]); + return rb_ary_new4(1, &ptr[RAND_UPTO(len)]); case 2: - i = (long)(rb_genrand_real()*len); - j = (long)(rb_genrand_real()*(len-1)); + i = RAND_UPTO(len); + j = RAND_UPTO(len-1); if (j >= i) j++; return rb_ary_new3(2, ptr[i], ptr[j]); case 3: - i = (long)(rb_genrand_real()*len); - j = (long)(rb_genrand_real()*(len-1)); - k = (long)(rb_genrand_real()*(len-2)); + i = RAND_UPTO(len); + j = RAND_UPTO(len-1); + k = RAND_UPTO(len-2); { long l = j, g = i; if (j >= i) l = i, g = ++j; @@ -3830,9 +3859,9 @@ rb_ary_sample(int argc, VALUE *argv, VALUE ary) if ((size_t)n < sizeof(idx)/sizeof(idx[0])) { VALUE *ptr_result; long sorted[sizeof(idx)/sizeof(idx[0])]; - sorted[0] = idx[0] = (long)(rb_genrand_real()*len); + sorted[0] = idx[0] = RAND_UPTO(len); for (i=1; i hash or nil @@ -432,7 +438,7 @@ to_hash(VALUE hash) static VALUE rb_hash_s_try_convert(VALUE dummy, VALUE hash) { - return rb_check_convert_type(hash, T_HASH, "Hash", "to_hash"); + return rb_check_hash_type(hash); } static int diff --git a/include/ruby/intern.h b/include/ruby/intern.h index c977a4b..748df32 100644 --- a/include/ruby/intern.h +++ b/include/ruby/intern.h @@ -406,6 +406,7 @@ VALUE rb_gc_start(void); #define Init_stack(addr) ruby_init_stack(addr) /* hash.c */ void st_foreach_safe(struct st_table *, int (*)(ANYARGS), st_data_t); +VALUE rb_check_hash_type(VALUE); void rb_hash_foreach(VALUE, int (*)(ANYARGS), VALUE); VALUE rb_hash(VALUE); VALUE rb_hash_new(void); diff --git a/random.c b/random.c index 49d0a75..ef0de1f 100644 --- a/random.c +++ b/random.c @@ -329,7 +329,12 @@ random_mark(void *ptr) rb_gc_mark(((rb_random_t *)ptr)->seed); } -#define random_free RUBY_TYPED_DEFAULT_FREE +static void +random_free(void *ptr) +{ + if (ptr != &default_rand) + xfree(ptr); +} static size_t random_memsize(const void *ptr) @@ -1232,6 +1237,8 @@ Init_Random(void) rb_define_private_method(rb_cRandom, "state", random_state, 0); rb_define_private_method(rb_cRandom, "left", random_left, 0); rb_define_method(rb_cRandom, "==", random_equal, 1); + rb_define_const(rb_cRandom, "DEFAULT", + TypedData_Wrap_Struct(rb_cRandom, &random_data_type, &default_rand)); rb_define_singleton_method(rb_cRandom, "srand", rb_f_srand, -1); rb_define_singleton_method(rb_cRandom, "rand", rb_f_rand, -1); diff --git a/test/ruby/test_array.rb b/test/ruby/test_array.rb index e8edcc2..b289af9 100644 --- a/test/ruby/test_array.rb +++ b/test/ruby/test_array.rb @@ -1878,6 +1878,12 @@ class TestArray < Test::Unit::TestCase 100.times do assert_equal([0, 1, 2], [2, 1, 0].shuffle.sort) end + + gen = Random.new(0) + srand(0) + 100.times do + assert_equal([0, 1, 2].shuffle, [0, 1, 2].shuffle(random: gen)) + end end def test_sample @@ -1894,7 +1900,7 @@ class TestArray < Test::Unit::TestCase (0..20).each do |n| 100.times do b = a.sample(n) - assert_equal([n, 18].min, b.uniq.size) + assert_equal([n, 18].min, b.size) assert_equal(a, (a | b).sort) assert_equal(b.sort, (a & b).sort) end @@ -1907,6 +1913,89 @@ class TestArray < Test::Unit::TestCase end assert_raise(ArgumentError, '[ruby-core:23374]') {[1, 2].sample(-1)} + + gen = Random.new(0) + srand(0) + a = (1..18).to_a + (0..20).each do |n| + 100.times do + assert_equal(a.sample(n), a.sample(n, random: gen)) + end + end + end + + def test_sample_without_replace + 100.times do + samples = [2, 1, 0].sample(2, replace: false) + samples.each{|sample| + assert([0, 1, 2].include?(sample)) + } + end + + srand(0) + a = (1..18).to_a + (0..20).each do |n| + 100.times do + b = a.sample(n, replace: false) + assert_equal([n, 18].min, b.size) + assert_equal(a, (a | b).sort) + assert_equal(b.sort, (a & b).sort) + end + + h = Hash.new(0) + 1000.times do + a.sample(n, replace: false).each {|x| h[x] += 1 } + end + assert_operator(h.values.min * 2, :>=, h.values.max) if n != 0 + end + + assert_raise(ArgumentError, '[ruby-core:23374]') {[1, 2].sample(-1, replace: false)} + + gen = Random.new(0) + srand(0) + a = (1..18).to_a + (0..20).each do |n| + 100.times do + assert_equal(a.sample(n, replace: false), a.sample(n, replace: false, random: gen)) + end + end + end + + def test_sample_with_replace + 100.times do + samples = [2, 1, 0].sample(2, replace: true) + samples.each{|sample| + assert([0, 1, 2].include?(sample)) + } + end + + srand(0) + a = (1..18).to_a + (0..20).each do |n| + 100.times do + b = a.sample(n, replace: true) + assert_equal(n, b.size) + assert_equal(a, (a | b).sort) + assert_equal(b.sort.uniq, (a & b).sort) + end + + h = Hash.new(0) + 1000.times do + a.sample(n, replace: true).each {|x| h[x] += 1 } + end + assert_operator(h.values.min * 2, :>=, h.values.max) if n != 0 + end + + assert_raise(ArgumentError, '[ruby-core:23374]') {[1, 2].sample(-1, replace: true)} + + gen = Random.new(0) + srand(0) + a = (1..18).to_a + (0..20).each do |n| + 100.times do + assert_equal(a.sample(n, replace: true), a.sample(n, replace: true, random: gen)) + end + end end def test_cycle