array_sample_shuffle.patch

Kenta Murata, 08/03/2010 06:25 PM

Download (9.62 KB)

View differences:

array.c
3732 3732

  
3733 3733
/*
3734 3734
 *  call-seq:
3735
 *     ary.shuffle!        -> ary
3735
 *     ary.shuffle!                  -> ary
3736
 *     ary.shuffle!(random: RANDOM)  -> ary
3736 3737
 *
3737 3738
 *  Shuffles elements in +self+ in place.
3739
 *  Unless an Random object is passed with :random keyword,
3740
 *  use default random number generator.
3738 3741
 */
3739 3742

  
3740 3743

  
3741 3744
static VALUE
3742
rb_ary_shuffle_bang(VALUE ary)
3745
rb_ary_shuffle_bang(int argc, VALUE* argv, VALUE ary)
3743 3746
{
3744 3747
    VALUE *ptr;
3748
    VALUE opt, random = Qnil;
3745 3749
    long i = RARRAY_LEN(ary);
3746 3750

  
3751
    if (argc > 0 && !NIL_P(opt = rb_check_hash_type(argv[argc-1]))) {
3752
	--argc;
3753
	random = rb_hash_aref(opt, ID2SYM(rb_intern("random")));
3754
    }
3755
    if (NIL_P(random)) {
3756
	random = rb_const_get_at(rb_cRandom, rb_intern("DEFAULT"));
3757
    }
3747 3758
    rb_ary_modify(ary);
3748 3759
    ptr = RARRAY_PTR(ary);
3749 3760
    while (i) {
3750
	long j = (long)(rb_genrand_real()*i);
3761
	long j = (long)(rb_random_real(random)*(i));
3751 3762
	VALUE tmp = ptr[--i];
3752 3763
	ptr[i] = ptr[j];
3753 3764
	ptr[j] = tmp;
......
3767 3778
 */
3768 3779

  
3769 3780
static VALUE
3770
rb_ary_shuffle(VALUE ary)
3781
rb_ary_shuffle(int argc, VALUE* argv, VALUE ary)
3771 3782
{
3772 3783
    ary = rb_ary_dup(ary);
3773
    rb_ary_shuffle_bang(ary);
3784
    rb_ary_shuffle_bang(argc, argv, ary);
3774 3785
    return ary;
3775 3786
}
3776 3787

  
......
3793 3804
rb_ary_sample(int argc, VALUE *argv, VALUE ary)
3794 3805
{
3795 3806
    VALUE nv, result, *ptr;
3807
    VALUE opt, replace = Qnil, random = Qnil;
3796 3808
    long n, len, i, j, k, idx[10];
3809
#define RAND_UPTO(n) (long)(rb_random_real(random)*(n))
3797 3810

  
3798 3811
    len = RARRAY_LEN(ary);
3812
    if (argc > 0 && !NIL_P(opt = rb_check_hash_type(argv[argc-1]))) {
3813
	--argc;
3814
	replace = rb_hash_aref(opt, ID2SYM(rb_intern("replace")));
3815
	random = rb_hash_aref(opt, ID2SYM(rb_intern("random")));
3816
    }
3817
    if (NIL_P(random)) {
3818
	random = rb_const_get_at(rb_cRandom, rb_intern("DEFAULT"));
3819
    }
3799 3820
    if (argc == 0) {
3800 3821
	if (len == 0) return Qnil;
3801
	i = len == 1 ? 0 : (long)(rb_genrand_real()*len);
3822
	i = len == 1 ? 0 : RAND_UPTO(len);
3802 3823
	return RARRAY_PTR(ary)[i];
3803 3824
    }
3804 3825
    rb_scan_args(argc, argv, "1", &nv);
......
3806 3827
    if (n < 0) rb_raise(rb_eArgError, "negative sample number");
3807 3828
    ptr = RARRAY_PTR(ary);
3808 3829
    len = RARRAY_LEN(ary);
3830
    RB_GC_GUARD(ary);
3831
    if (RTEST(replace)) {
3832
	result = rb_ary_new2(n);
3833
	while (n-- > 0) {
3834
	    rb_ary_push(result, ptr[RAND_UPTO(len)]);
3835
	}
3836
	return result;
3837
    }
3809 3838
    if (n > len) n = len;
3810 3839
    switch (n) {
3811 3840
      case 0: return rb_ary_new2(0);
3812 3841
      case 1:
3813
	return rb_ary_new4(1, &ptr[(long)(rb_genrand_real()*len)]);
3842
	return rb_ary_new4(1, &ptr[RAND_UPTO(len)]);
3814 3843
      case 2:
3815
	i = (long)(rb_genrand_real()*len);
3816
	j = (long)(rb_genrand_real()*(len-1));
3844
	i = RAND_UPTO(len);
3845
	j = RAND_UPTO(len-1);
3817 3846
	if (j >= i) j++;
3818 3847
	return rb_ary_new3(2, ptr[i], ptr[j]);
3819 3848
      case 3:
3820
	i = (long)(rb_genrand_real()*len);
3821
	j = (long)(rb_genrand_real()*(len-1));
3822
	k = (long)(rb_genrand_real()*(len-2));
3849
	i = RAND_UPTO(len);
3850
	j = RAND_UPTO(len-1);
3851
	k = RAND_UPTO(len-2);
3823 3852
	{
3824 3853
	    long l = j, g = i;
3825 3854
	    if (j >= i) l = i, g = ++j;
......
3830 3859
    if ((size_t)n < sizeof(idx)/sizeof(idx[0])) {
3831 3860
	VALUE *ptr_result;
3832 3861
	long sorted[sizeof(idx)/sizeof(idx[0])];
3833
	sorted[0] = idx[0] = (long)(rb_genrand_real()*len);
3862
	sorted[0] = idx[0] = RAND_UPTO(len);
3834 3863
	for (i=1; i<n; i++) {
3835
	    k = (long)(rb_genrand_real()*--len);
3864
	    k = RAND_UPTO(--len);
3836 3865
	    for (j = 0; j < i; ++j) {
3837 3866
		if (k < sorted[j]) break;
3838 3867
		++k;
......
3852 3881
	ptr_result = RARRAY_PTR(result);
3853 3882
	RB_GC_GUARD(ary);
3854 3883
	for (i=0; i<n; i++) {
3855
	    j = (long)(rb_genrand_real()*(len-i)) + i;
3884
	    j = RAND_UPTO(len-i) + i;
3856 3885
	    nv = ptr_result[j];
3857 3886
	    ptr_result[j] = ptr_result[i];
3858 3887
	    ptr_result[i] = nv;
......
3861 3890
    ARY_SET_LEN(result, n);
3862 3891

  
3863 3892
    return result;
3893
#undef RAND_UPTO
3864 3894
}
3865 3895

  
3866 3896

  
......
4593 4623
    rb_define_method(rb_cArray, "flatten", rb_ary_flatten, -1);
4594 4624
    rb_define_method(rb_cArray, "flatten!", rb_ary_flatten_bang, -1);
4595 4625
    rb_define_method(rb_cArray, "count", rb_ary_count, -1);
4596
    rb_define_method(rb_cArray, "shuffle!", rb_ary_shuffle_bang, 0);
4597
    rb_define_method(rb_cArray, "shuffle", rb_ary_shuffle, 0);
4626
    rb_define_method(rb_cArray, "shuffle!", rb_ary_shuffle_bang, -1);
4627
    rb_define_method(rb_cArray, "shuffle", rb_ary_shuffle, -1);
4598 4628
    rb_define_method(rb_cArray, "sample", rb_ary_sample, -1);
4599 4629
    rb_define_method(rb_cArray, "cycle", rb_ary_cycle, -1);
4600 4630
    rb_define_method(rb_cArray, "permutation", rb_ary_permutation, -1);
hash.c
418 418
    return rb_convert_type(hash, T_HASH, "Hash", "to_hash");
419 419
}
420 420

  
421
VALUE
422
rb_check_hash_type(VALUE hash)
423
{
424
    return rb_check_convert_type(hash, T_HASH, "Hash", "to_hash");
425
}
426

  
421 427
/*
422 428
 *  call-seq:
423 429
 *     Hash.try_convert(obj) -> hash or nil
......
432 438
static VALUE
433 439
rb_hash_s_try_convert(VALUE dummy, VALUE hash)
434 440
{
435
    return rb_check_convert_type(hash, T_HASH, "Hash", "to_hash");
441
    return rb_check_hash_type(hash);
436 442
}
437 443

  
438 444
static int
include/ruby/intern.h
406 406
#define Init_stack(addr) ruby_init_stack(addr)
407 407
/* hash.c */
408 408
void st_foreach_safe(struct st_table *, int (*)(ANYARGS), st_data_t);
409
VALUE rb_check_hash_type(VALUE);
409 410
void rb_hash_foreach(VALUE, int (*)(ANYARGS), VALUE);
410 411
VALUE rb_hash(VALUE);
411 412
VALUE rb_hash_new(void);
random.c
329 329
    rb_gc_mark(((rb_random_t *)ptr)->seed);
330 330
}
331 331

  
332
#define random_free RUBY_TYPED_DEFAULT_FREE
332
static void
333
random_free(void *ptr)
334
{
335
    if (ptr != &default_rand)
336
	xfree(ptr);
337
}
333 338

  
334 339
static size_t
335 340
random_memsize(const void *ptr)
......
1232 1237
    rb_define_private_method(rb_cRandom, "state", random_state, 0);
1233 1238
    rb_define_private_method(rb_cRandom, "left", random_left, 0);
1234 1239
    rb_define_method(rb_cRandom, "==", random_equal, 1);
1240
    rb_define_const(rb_cRandom, "DEFAULT",
1241
		    TypedData_Wrap_Struct(rb_cRandom, &random_data_type, &default_rand));
1235 1242

  
1236 1243
    rb_define_singleton_method(rb_cRandom, "srand", rb_f_srand, -1);
1237 1244
    rb_define_singleton_method(rb_cRandom, "rand", rb_f_rand, -1);
test/ruby/test_array.rb
1878 1878
    100.times do
1879 1879
      assert_equal([0, 1, 2], [2, 1, 0].shuffle.sort)
1880 1880
    end
1881

  
1882
    gen = Random.new(0)
1883
    srand(0)
1884
    100.times do
1885
      assert_equal([0, 1, 2].shuffle, [0, 1, 2].shuffle(random: gen))
1886
    end
1881 1887
  end
1882 1888

  
1883 1889
  def test_sample
......
1894 1900
    (0..20).each do |n|
1895 1901
      100.times do
1896 1902
        b = a.sample(n)
1897
        assert_equal([n, 18].min, b.uniq.size)
1903
        assert_equal([n, 18].min, b.size)
1898 1904
        assert_equal(a, (a | b).sort)
1899 1905
        assert_equal(b.sort, (a & b).sort)
1900 1906
      end
......
1907 1913
    end
1908 1914

  
1909 1915
    assert_raise(ArgumentError, '[ruby-core:23374]') {[1, 2].sample(-1)}
1916

  
1917
    gen = Random.new(0)
1918
    srand(0)
1919
    a = (1..18).to_a
1920
    (0..20).each do |n|
1921
      100.times do
1922
        assert_equal(a.sample(n), a.sample(n, random: gen))
1923
      end
1924
    end
1925
  end
1926

  
1927
  def test_sample_without_replace
1928
    100.times do
1929
      samples = [2, 1, 0].sample(2, replace: false)
1930
      samples.each{|sample|
1931
        assert([0, 1, 2].include?(sample))
1932
      }
1933
    end
1934

  
1935
    srand(0)
1936
    a = (1..18).to_a
1937
    (0..20).each do |n|
1938
      100.times do
1939
        b = a.sample(n, replace: false)
1940
        assert_equal([n, 18].min, b.size)
1941
        assert_equal(a, (a | b).sort)
1942
        assert_equal(b.sort, (a & b).sort)
1943
      end
1944

  
1945
      h = Hash.new(0)
1946
      1000.times do
1947
        a.sample(n, replace: false).each {|x| h[x] += 1 }
1948
      end
1949
      assert_operator(h.values.min * 2, :>=, h.values.max) if n != 0
1950
    end
1951

  
1952
    assert_raise(ArgumentError, '[ruby-core:23374]') {[1, 2].sample(-1, replace: false)}
1953

  
1954
    gen = Random.new(0)
1955
    srand(0)
1956
    a = (1..18).to_a
1957
    (0..20).each do |n|
1958
      100.times do
1959
        assert_equal(a.sample(n, replace: false), a.sample(n, replace: false, random: gen))
1960
      end
1961
    end
1962
  end
1963

  
1964
  def test_sample_with_replace
1965
    100.times do
1966
      samples = [2, 1, 0].sample(2, replace: true)
1967
      samples.each{|sample|
1968
        assert([0, 1, 2].include?(sample))
1969
      }
1970
    end
1971

  
1972
    srand(0)
1973
    a = (1..18).to_a
1974
    (0..20).each do |n|
1975
      100.times do
1976
        b = a.sample(n, replace: true)
1977
        assert_equal(n, b.size)
1978
        assert_equal(a, (a | b).sort)
1979
        assert_equal(b.sort.uniq, (a & b).sort)
1980
      end
1981

  
1982
      h = Hash.new(0)
1983
      1000.times do
1984
        a.sample(n, replace: true).each {|x| h[x] += 1 }
1985
      end
1986
      assert_operator(h.values.min * 2, :>=, h.values.max) if n != 0
1987
    end
1988

  
1989
    assert_raise(ArgumentError, '[ruby-core:23374]') {[1, 2].sample(-1, replace: true)}
1990

  
1991
    gen = Random.new(0)
1992
    srand(0)
1993
    a = (1..18).to_a
1994
    (0..20).each do |n|
1995
      100.times do
1996
        assert_equal(a.sample(n, replace: true), a.sample(n, replace: true, random: gen))
1997
      end
1998
    end
1910 1999
  end
1911 2000

  
1912 2001
  def test_cycle