Project

General

Profile

Feature #3649 » array_sample_shuffle_random.patch

mrkn (Kenta Murata), 08/03/2010 08:29 PM

View differences:

array.c
3732 3732

  
3733 3733
/*
3734 3734
 *  call-seq:
3735
 *     ary.shuffle!        -> ary
3735
 *     ary.shuffle!                  -> ary
3736
 *     ary.shuffle!(random: random)  -> ary
3736 3737
 *
3737 3738
 *  Shuffles elements in +self+ in place.
3738 3739
 */
3739 3740

  
3740 3741

  
3741 3742
static VALUE
3742
rb_ary_shuffle_bang(VALUE ary)
3743
rb_ary_shuffle_bang(int argc, VALUE* argv, VALUE ary)
3743 3744
{
3744 3745
    VALUE *ptr;
3746
    VALUE opt, random = Qnil;
3745 3747
    long i = RARRAY_LEN(ary);
3746 3748

  
3749
    if (argc > 0 && !NIL_P(opt = rb_check_hash_type(argv[argc-1]))) {
3750
	--argc;
3751
	random = rb_hash_aref(opt, ID2SYM(rb_intern("random")));
3752
    }
3753
    if (NIL_P(random)) {
3754
	random = rb_const_get_at(rb_cRandom, rb_intern("DEFAULT"));
3755
    }
3747 3756
    rb_ary_modify(ary);
3748 3757
    ptr = RARRAY_PTR(ary);
3749 3758
    while (i) {
3750
	long j = (long)(rb_genrand_real()*i);
3759
	long j = (long)(rb_random_real(random)*(i));
3751 3760
	VALUE tmp = ptr[--i];
3752 3761
	ptr[i] = ptr[j];
3753 3762
	ptr[j] = tmp;
......
3758 3767

  
3759 3768
/*
3760 3769
 *  call-seq:
3761
 *     ary.shuffle -> new_ary
3770
 *     ary.shuffle                 -> new_ary
3771
 *     ary.shuffle(random: random) -> new_ary
3762 3772
 *
3763 3773
 *  Returns a new array with elements of this array shuffled.
3764 3774
 *
......
3767 3777
 */
3768 3778

  
3769 3779
static VALUE
3770
rb_ary_shuffle(VALUE ary)
3780
rb_ary_shuffle(int argc, VALUE* argv, VALUE ary)
3771 3781
{
3772 3782
    ary = rb_ary_dup(ary);
3773
    rb_ary_shuffle_bang(ary);
3783
    rb_ary_shuffle_bang(argc, argv, ary);
3774 3784
    return ary;
3775 3785
}
3776 3786

  
3777 3787

  
3778 3788
/*
3779 3789
 *  call-seq:
3780
 *     ary.sample        -> obj
3781
 *     ary.sample(n)     -> new_ary
3790
 *     ary.sample                     -> obj
3791
 *     ary.sample(random: random)     -> obj
3792
 *     ary.sample(n)                  -> new_ary
3793
 *     ary.sample(n, random: random)  -> new_ary
3782 3794
 *
3783 3795
 *  Choose a random element or +n+ random elements from the array. The elements
3784 3796
 *  are chosen by using random and unique indices into the array in order to
......
3793 3805
rb_ary_sample(int argc, VALUE *argv, VALUE ary)
3794 3806
{
3795 3807
    VALUE nv, result, *ptr;
3808
    VALUE opt, random = Qnil;
3796 3809
    long n, len, i, j, k, idx[10];
3810
#define RAND_UPTO(n) (long)(rb_random_real(random)*(n))
3797 3811

  
3798 3812
    len = RARRAY_LEN(ary);
3813
    if (argc > 0 && !NIL_P(opt = rb_check_hash_type(argv[argc-1]))) {
3814
	--argc;
3815
	random = rb_hash_aref(opt, ID2SYM(rb_intern("random")));
3816
    }
3817
    if (NIL_P(random)) {
3818
	random = rb_const_get_at(rb_cRandom, rb_intern("DEFAULT"));
3819
    }
3799 3820
    if (argc == 0) {
3800 3821
	if (len == 0) return Qnil;
3801
	i = len == 1 ? 0 : (long)(rb_genrand_real()*len);
3822
	i = len == 1 ? 0 : RAND_UPTO(len);
3802 3823
	return RARRAY_PTR(ary)[i];
3803 3824
    }
3804 3825
    rb_scan_args(argc, argv, "1", &nv);
......
3806 3827
    if (n < 0) rb_raise(rb_eArgError, "negative sample number");
3807 3828
    ptr = RARRAY_PTR(ary);
3808 3829
    len = RARRAY_LEN(ary);
3830
    RB_GC_GUARD(ary);
3809 3831
    if (n > len) n = len;
3810 3832
    switch (n) {
3811 3833
      case 0: return rb_ary_new2(0);
3812 3834
      case 1:
3813
	return rb_ary_new4(1, &ptr[(long)(rb_genrand_real()*len)]);
3835
	return rb_ary_new4(1, &ptr[RAND_UPTO(len)]);
3814 3836
      case 2:
3815
	i = (long)(rb_genrand_real()*len);
3816
	j = (long)(rb_genrand_real()*(len-1));
3837
	i = RAND_UPTO(len);
3838
	j = RAND_UPTO(len-1);
3817 3839
	if (j >= i) j++;
3818 3840
	return rb_ary_new3(2, ptr[i], ptr[j]);
3819 3841
      case 3:
3820
	i = (long)(rb_genrand_real()*len);
3821
	j = (long)(rb_genrand_real()*(len-1));
3822
	k = (long)(rb_genrand_real()*(len-2));
3842
	i = RAND_UPTO(len);
3843
	j = RAND_UPTO(len-1);
3844
	k = RAND_UPTO(len-2);
3823 3845
	{
3824 3846
	    long l = j, g = i;
3825 3847
	    if (j >= i) l = i, g = ++j;
......
3830 3852
    if ((size_t)n < sizeof(idx)/sizeof(idx[0])) {
3831 3853
	VALUE *ptr_result;
3832 3854
	long sorted[sizeof(idx)/sizeof(idx[0])];
3833
	sorted[0] = idx[0] = (long)(rb_genrand_real()*len);
3855
	sorted[0] = idx[0] = RAND_UPTO(len);
3834 3856
	for (i=1; i<n; i++) {
3835
	    k = (long)(rb_genrand_real()*--len);
3857
	    k = RAND_UPTO(--len);
3836 3858
	    for (j = 0; j < i; ++j) {
3837 3859
		if (k < sorted[j]) break;
3838 3860
		++k;
......
3850 3872
	VALUE *ptr_result;
3851 3873
	result = rb_ary_new4(len, ptr);
3852 3874
	ptr_result = RARRAY_PTR(result);
3853
	RB_GC_GUARD(ary);
3854 3875
	for (i=0; i<n; i++) {
3855
	    j = (long)(rb_genrand_real()*(len-i)) + i;
3876
	    j = RAND_UPTO(len-i) + i;
3856 3877
	    nv = ptr_result[j];
3857 3878
	    ptr_result[j] = ptr_result[i];
3858 3879
	    ptr_result[i] = nv;
......
3861 3882
    ARY_SET_LEN(result, n);
3862 3883

  
3863 3884
    return result;
3885
#undef RAND_UPTO
3864 3886
}
3865 3887

  
3866 3888

  
......
4593 4615
    rb_define_method(rb_cArray, "flatten", rb_ary_flatten, -1);
4594 4616
    rb_define_method(rb_cArray, "flatten!", rb_ary_flatten_bang, -1);
4595 4617
    rb_define_method(rb_cArray, "count", rb_ary_count, -1);
4596
    rb_define_method(rb_cArray, "shuffle!", rb_ary_shuffle_bang, 0);
4597
    rb_define_method(rb_cArray, "shuffle", rb_ary_shuffle, 0);
4618
    rb_define_method(rb_cArray, "shuffle!", rb_ary_shuffle_bang, -1);
4619
    rb_define_method(rb_cArray, "shuffle", rb_ary_shuffle, -1);
4598 4620
    rb_define_method(rb_cArray, "sample", rb_ary_sample, -1);
4599 4621
    rb_define_method(rb_cArray, "cycle", rb_ary_cycle, -1);
4600 4622
    rb_define_method(rb_cArray, "permutation", rb_ary_permutation, -1);
hash.c
418 418
    return rb_convert_type(hash, T_HASH, "Hash", "to_hash");
419 419
}
420 420

  
421
VALUE
422
rb_check_hash_type(VALUE hash)
423
{
424
    return rb_check_convert_type(hash, T_HASH, "Hash", "to_hash");
425
}
426

  
421 427
/*
422 428
 *  call-seq:
423 429
 *     Hash.try_convert(obj) -> hash or nil
......
432 438
static VALUE
433 439
rb_hash_s_try_convert(VALUE dummy, VALUE hash)
434 440
{
435
    return rb_check_convert_type(hash, T_HASH, "Hash", "to_hash");
441
    return rb_check_hash_type(hash);
436 442
}
437 443

  
438 444
static int
include/ruby/intern.h
406 406
#define Init_stack(addr) ruby_init_stack(addr)
407 407
/* hash.c */
408 408
void st_foreach_safe(struct st_table *, int (*)(ANYARGS), st_data_t);
409
VALUE rb_check_hash_type(VALUE);
409 410
void rb_hash_foreach(VALUE, int (*)(ANYARGS), VALUE);
410 411
VALUE rb_hash(VALUE);
411 412
VALUE rb_hash_new(void);
random.c
329 329
    rb_gc_mark(((rb_random_t *)ptr)->seed);
330 330
}
331 331

  
332
#define random_free RUBY_TYPED_DEFAULT_FREE
332
static void
333
random_free(void *ptr)
334
{
335
    if (ptr != &default_rand)
336
	xfree(ptr);
337
}
333 338

  
334 339
static size_t
335 340
random_memsize(const void *ptr)
......
1232 1237
    rb_define_private_method(rb_cRandom, "state", random_state, 0);
1233 1238
    rb_define_private_method(rb_cRandom, "left", random_left, 0);
1234 1239
    rb_define_method(rb_cRandom, "==", random_equal, 1);
1240
    rb_define_const(rb_cRandom, "DEFAULT",
1241
		    TypedData_Wrap_Struct(rb_cRandom, &random_data_type, &default_rand));
1235 1242

  
1236 1243
    rb_define_singleton_method(rb_cRandom, "srand", rb_f_srand, -1);
1237 1244
    rb_define_singleton_method(rb_cRandom, "rand", rb_f_rand, -1);
test/ruby/test_array.rb
1878 1878
    100.times do
1879 1879
      assert_equal([0, 1, 2], [2, 1, 0].shuffle.sort)
1880 1880
    end
1881

  
1882
    gen = Random.new(0)
1883
    srand(0)
1884
    100.times do
1885
      assert_equal([0, 1, 2].shuffle, [0, 1, 2].shuffle(random: gen))
1886
    end
1881 1887
  end
1882 1888

  
1883 1889
  def test_sample
......
1894 1900
    (0..20).each do |n|
1895 1901
      100.times do
1896 1902
        b = a.sample(n)
1897
        assert_equal([n, 18].min, b.uniq.size)
1903
        assert_equal([n, 18].min, b.size)
1898 1904
        assert_equal(a, (a | b).sort)
1899 1905
        assert_equal(b.sort, (a & b).sort)
1900 1906
      end
......
1907 1913
    end
1908 1914

  
1909 1915
    assert_raise(ArgumentError, '[ruby-core:23374]') {[1, 2].sample(-1)}
1916

  
1917
    gen = Random.new(0)
1918
    srand(0)
1919
    a = (1..18).to_a
1920
    (0..20).each do |n|
1921
      100.times do
1922
        assert_equal(a.sample(n), a.sample(n, random: gen))
1923
      end
1924
    end
1910 1925
  end
1911 1926

  
1912 1927
  def test_cycle