bsearch.patch

Yusuke Endoh, 07/20/2011 11:21 AM

Download (12.9 KB)

View differences:

array.c
2158 2158
    return ary;
2159 2159
}
2160 2160

  
2161
/*
2162
 *  call-seq:
2163
 *     ary.bsearch {|x| block }  -> elem
2164
 */
2165

  
2166
static VALUE
2167
rb_ary_bsearch(VALUE ary)
2168
{
2169
    long low = 0, high = RARRAY_LEN(ary), mid;
2170
    int smaller, satisfied = 0;
2171
    VALUE v, val;
2172

  
2173
    while (low < high) {
2174
	mid = low + ((high - low) / 2);
2175
	val = rb_ary_entry(ary, mid);
2176
	v = rb_yield(val);
2177
	if (FIXNUM_P(v)) {
2178
	    if (FIX2INT(v) == 0) return val;
2179
	    smaller = FIX2INT(v) < 0;
2180
	}
2181
	else if (v == Qtrue) {
2182
	    satisfied = 1;
2183
	    smaller = 1;
2184
	}
2185
	else if (v == Qfalse || v == Qnil) {
2186
	    smaller = 0;
2187
	}
2188
	else if (rb_obj_is_kind_of(v, rb_cNumeric)) {
2189
	    switch (rb_cmpint(rb_funcall(v, id_cmp, 1, INT2FIX(0)), v, INT2FIX(0)) < 0) {
2190
		case 0: return val;
2191
		case 1: smaller = 1;
2192
		case -1: smaller = 0;
2193
	    }
2194
	}
2195
	else {
2196
	    smaller = RTEST(v);
2197
	}
2198
	if (smaller) {
2199
	    high = mid;
2200
	}
2201
	else {
2202
	    low = mid + 1;
2203
	}
2204
    }
2205
    if (low == RARRAY_LEN(ary)) return Qnil;
2206
    if (!satisfied) return Qnil;
2207
    return rb_ary_entry(ary, low);
2208
}
2209

  
2161 2210

  
2162 2211
static VALUE
2163 2212
sort_by_i(VALUE i)
......
4745 4794
    rb_define_method(rb_cArray, "take_while", rb_ary_take_while, 0);
4746 4795
    rb_define_method(rb_cArray, "drop", rb_ary_drop, 1);
4747 4796
    rb_define_method(rb_cArray, "drop_while", rb_ary_drop_while, 0);
4797
    rb_define_method(rb_cArray, "bsearch", rb_ary_bsearch, 0);
4748 4798

  
4749 4799
    id_cmp = rb_intern("<=>");
4750 4800
    sym_random = ID2SYM(rb_intern("random"));
range.c
13 13
#include "ruby/encoding.h"
14 14
#include "internal.h"
15 15

  
16
#ifdef HAVE_FLOAT_H
17
#include <float.h>
18
#endif
19
#include <math.h>
20

  
16 21
VALUE rb_cRange;
17 22
static ID id_cmp, id_succ, id_beg, id_end, id_excl;
18 23

  
......
436 441
    return range;
437 442
}
438 443

  
444
/*
445
 *  call-seq:
446
 *     rng.bsearch {|obj| block }  -> element
447
 *
448
 *  Finds a value in range which meets the given condition in O(n log n)
449
 *  where n = (rng.begin - rng.end), by using binary search.
450
 *
451
 *  The given block receives a current value, determines if it meets the
452
 *  condition and controls search.
453
 *  When the condition is satisfied and you want to stop search, the block
454
 *  should return zero, and then this method return the value immediately.
455
 *  When the condition is satisfied and you want to find minimum bound,
456
 *  the block should return true.  When the condition is not satisfied and
457
 *  the current value is smaller than wanted, the block should return false,
458
 *  nil or an integer greater than zero. When the condition is not satisfied
459
 *  and the current value is larger than wanted, the block should return an
460
 *  integer less than zero.
461
 *  Unless the block returns zero, the search will continue until a minimum
462
 *  bound is found or no match is found.  Returns the minimum bound if any,
463
 *  or returns nil when no match is found.
464
 *
465
 *  The block must be monotone; there must be two values a and b so that
466
 *  the block returns:
467
 *  - false, nil or an integer greater than zero for all x of [begin of
468
 *    range, a), and
469
 *  - zero or true for all x of [a, b), and
470
 *  - an integer less than zero for all x of [b, end of range).
471
 *  If the block is not monotone, the result is unspecified.
472
 *
473
 *  This method takes O(n log n), but it is unspecified which value is
474
 *  actually picked up at each iteration.
475
 *
476
 *     ary = [0, 4, 7, 10, 12]
477
 *     (0...ary.size).bsearch {|i| ary[i] >= 4 } #=> 1
478
 *     (0...ary.size).bsearch {|i| ary[i] >= 6 } #=> 2
479
 *     (0...ary.size).bsearch {|i| ary[i] >= 8 } #=> 3
480
 *     (0...ary.size).bsearch {|i| ary[i] >= 100 } #=> nil
481
 *
482
 *     (0.0...Float::INFINITY).bsearch {|x| Math.log(x) >= 0 } #=> 1.0
483
 *
484
 *     ary = [0, 100, 100, 100, 200]
485
 *     (0..4).bsearch {|i| 100 - i } #=> 1, 2 or 3
486
 *     (0..4).bsearch {|i| 300 - i } #=> nil
487
 *     (0..4).bsearch {|i|  50 - i } #=> nil
488
 */
489

  
490
static VALUE
491
range_bsearch(VALUE range)
492
{
493
    VALUE beg, end;
494
    int smaller, satisfied = 0;
495

  
496
#define BSEARCH_CHECK(val) \
497
    do { \
498
	VALUE v = rb_yield(val); \
499
	if (FIXNUM_P(v)) { \
500
	    if (FIX2INT(v) == 0) return val; \
501
	    smaller = FIX2INT(v) < 0; \
502
	} \
503
	else if (v == Qtrue) { \
504
	    satisfied = 1; \
505
	    smaller = 1; \
506
	} \
507
	else if (v == Qfalse || v == Qnil) { \
508
	    smaller = 0; \
509
	} \
510
	else if (rb_obj_is_kind_of(v, rb_cNumeric)) { \
511
	    switch (rb_cmpint(rb_funcall(v, id_cmp, 1, INT2FIX(0)), v, INT2FIX(0)) < 0) { \
512
		case 0: return val; \
513
		case 1: smaller = 1; \
514
		case -1: smaller = 0; \
515
	    } \
516
	} \
517
	else { \
518
	    smaller = RTEST(v); \
519
	} \
520
    } while (0)
521

  
522
    beg = RANGE_BEG(range);
523
    end = RANGE_END(range);
524

  
525
    if (FIXNUM_P(beg) && FIXNUM_P(end)) {
526
	long low = FIX2LONG(beg);
527
	long high = FIX2LONG(end);
528
	long mid, org_high;
529
	if (EXCL(range)) high--;
530
	org_high = high;
531

  
532
	while (low < high) {
533
	    mid = low + ((high - low) / 2);
534
	    BSEARCH_CHECK(INT2FIX(mid));
535
	    if (smaller) {
536
		high = mid;
537
	    }
538
	    else {
539
		low = mid + 1;
540
	    }
541
	}
542
	if (low == org_high) {
543
	    BSEARCH_CHECK(INT2FIX(low));
544
	    if (!smaller) return Qnil;
545
	}
546
	if (!satisfied) return Qnil;
547
	return INT2FIX(low);
548
    }
549
    else if (TYPE(beg) == T_FLOAT || TYPE(end) == T_FLOAT) {
550
	double low  = RFLOAT_VALUE(rb_Float(beg));
551
	double high = RFLOAT_VALUE(rb_Float(end));
552
	double mid, org_high, last_found_key = 0.0;
553
	int count, found = 0;
554
#ifdef FLT_RADIX
555
#ifdef DBL_MANT_DIG
556
#define COUNT (((FLT_RADIX) - 1) * (DBL_MANT_DIG + DBL_MAX_EXP) + 100)
557
#else
558
#define count (53 + 1023 + 100)
559
#endif
560
#else
561
#define count (53 + 1023 + 100)
562
#endif
563
	if (isinf(high) && high > 0) {
564
	    double nhigh = 1.0, inc;
565
	    if (nhigh < low) nhigh = low;
566
	    count = COUNT;
567
	    while (count >= 0 && !isinf(nhigh)) {
568
		BSEARCH_CHECK(DBL2NUM(nhigh));
569
		if (smaller) break;
570
		high = nhigh;
571
		nhigh *= 2;
572
		count--;
573
	    }
574
	    if (isinf(nhigh) || count < 0) {
575
		inc = high / 2;
576
		count = COUNT;
577
		while (count >= 0 && inc > 0) {
578
		    nhigh = high + inc;
579
		    if (!isinf(nhigh)) {
580
			BSEARCH_CHECK(DBL2NUM(nhigh));
581
			if (smaller) {
582
			    low = high;
583
			    high = nhigh;
584
			    goto binsearch;
585
			}
586
			else {
587
			    high = nhigh;
588
			}
589
		    }
590
		    inc /= 2;
591
		    count--;
592
		}
593
		high *= 2; /* generate infinity */
594
		if (isinf(high) && !EXCL(range)) {
595
		    BSEARCH_CHECK(DBL2NUM(high));
596
		    if (!satisfied) return Qnil;
597
		    if (smaller) return DBL2NUM(high);
598
		}
599
		return Qnil;
600
	    }
601
	    high = nhigh;
602
	}
603
	if (isinf(low) && low < 0) {
604
	    double nlow = -1.0, dec;
605
	    if (nlow > high) nlow = high;
606
	    count = COUNT;
607
	    while (count >= 0 && !isinf(nlow)) {
608
		BSEARCH_CHECK(DBL2NUM(nlow));
609
		if (!smaller) break;
610
		low = nlow;
611
		nlow *= 2;
612
		count--;
613
	    }
614
	    if (isinf(nlow) || count < 0) {
615
		dec = low / 2;
616
		count = COUNT;
617
		while (count >= 0 && dec < 0) {
618
		    nlow = low + dec;
619
		    if (!isinf(nlow)) {
620
			BSEARCH_CHECK(DBL2NUM(nlow));
621
			if (!smaller) {
622
			    high = low;
623
			    low = nlow;
624
			    goto binsearch;
625
			}
626
			else {
627
			    low = nlow;
628
			}
629
		    }
630
		    dec /= 2;
631
		    count--;
632
		}
633
		nlow = low * 2; /* generate infinity */
634
		if (isinf(nlow)) {
635
		    BSEARCH_CHECK(DBL2NUM(nlow));
636
		    if (!satisfied) return Qnil;
637
		    if (smaller) return DBL2NUM(nlow);
638
		}
639
		if (!satisfied) return Qnil;
640
		return DBL2NUM(low);
641
	    }
642
	    low = nlow;
643
	}
644

  
645
    binsearch:
646
	org_high = high;
647
	count = COUNT;
648
	while (low < high && count >= 0) {
649
	    mid = low + ((high - low) / 2);
650
	    BSEARCH_CHECK(DBL2NUM(mid));
651
	    if (smaller) {
652
		found = 1;
653
		last_found_key = high;
654
		high = mid;
655
	    }
656
	    else {
657
		low = mid;
658
	    }
659
	    count--;
660
	}
661
	BSEARCH_CHECK(DBL2NUM(low));
662
	if (!smaller) {
663
	    BSEARCH_CHECK(DBL2NUM(high));
664
	    if (!smaller) {
665
		if (found) {
666
		    low = last_found_key;
667
		}
668
		else {
669
		    return Qnil;
670
		}
671
	    }
672
	    low = high;
673
	}
674
	if (!satisfied) return Qnil;
675
	if (EXCL(range) && low >= org_high) return Qnil;
676
	return DBL2NUM(low);
677
#undef COUNT
678
    }
679
    else if (!NIL_P(rb_check_to_integer(beg, "to_int")) &&
680
	     !NIL_P(rb_check_to_integer(end, "to_int"))) {
681
	VALUE low = beg;
682
	VALUE high = end;
683
	VALUE mid, org_high;
684
	if (EXCL(range)) high = rb_funcall(high, '-', 1, INT2FIX(1));
685
	org_high = high;
686

  
687
	while (rb_cmpint(rb_funcall(low, id_cmp, 1, high), low, high) < 0) {
688
	    mid = rb_funcall(rb_funcall(high, '+', 1, low), '/', 1, INT2FIX(2));
689
	    BSEARCH_CHECK(mid);
690
	    if (smaller) {
691
		high = mid;
692
	    }
693
	    else {
694
		low = rb_funcall(mid, '+', 1, INT2FIX(1));
695
	    }
696
	}
697
	if (rb_equal(low, org_high)) {
698
	    BSEARCH_CHECK(low);
699
	    if (!smaller) return Qnil;
700
	}
701
	if (!satisfied) return Qnil;
702
	return low;
703
    }
704
    else {
705
	rb_raise(rb_eTypeError, "can't do binary search for %s", rb_obj_classname(beg));
706
    }
707
    return range;
708
    
709
}
710

  
439 711
static VALUE
440 712
each_i(VALUE v, void *arg)
441 713
{
......
1035 1307
    rb_define_method(rb_cRange, "hash", range_hash, 0);
1036 1308
    rb_define_method(rb_cRange, "each", range_each, 0);
1037 1309
    rb_define_method(rb_cRange, "step", range_step, -1);
1310
    rb_define_method(rb_cRange, "bsearch", range_bsearch, 0);
1038 1311
    rb_define_method(rb_cRange, "begin", range_begin, 0);
1039 1312
    rb_define_method(rb_cRange, "end", range_end, 0);
1040 1313
    rb_define_method(rb_cRange, "first", range_first, -1);
test/ruby/test_range.rb
347 347
      assert !x.eql?(z)
348 348
    }
349 349
  end
350

  
351
  def test_bsearch_for_fixnum
352
    ary = [3, 4, 7, 9, 12]
353
    assert_equal(0, (0...ary.size).bsearch {|i| ary[i] >= 2 })
354
    assert_equal(1, (0...ary.size).bsearch {|i| ary[i] >= 4 })
355
    assert_equal(2, (0...ary.size).bsearch {|i| ary[i] >= 6 })
356
    assert_equal(3, (0...ary.size).bsearch {|i| ary[i] >= 8 })
357
    assert_equal(4, (0...ary.size).bsearch {|i| ary[i] >= 10 })
358
    assert_equal(nil, (0...ary.size).bsearch {|i| ary[i] >= 100 })
359
    assert_equal(0, (0...ary.size).bsearch {|i| true })
360
    assert_equal(nil, (0...ary.size).bsearch {|i| false })
361

  
362
    ary = [0, 100, 100, 100, 200]
363
    assert_equal(1, (0...ary.size).bsearch {|i| ary[i] >= 100 })
364
  end
365

  
366
  def test_bsearch_for_float
367
    inf = Float::INFINITY
368
    assert_in_delta(10.0, (0.0...100.0).bsearch {|x| x > 0 && Math.log(x / 10) >= 0 }, 0.0001)
369
    assert_in_delta(10.0, (0.0...inf).bsearch {|x| x > 0 && Math.log(x / 10) >= 0 }, 0.0001)
370
    assert_in_delta(-10.0, (-inf..100.0).bsearch {|x| x >= 0 || Math.log(-x / 10) < 0 }, 0.0001)
371
    assert_in_delta(10.0, (-inf..inf).bsearch {|x| x > 0 && Math.log(x / 10) >= 0 }, 0.0001)
372
    assert_equal(nil, (-inf..5).bsearch {|x| x > 0 && Math.log(x / 10) >= 0 }, 0.0001)
373

  
374
    assert_in_delta(10.0, (-inf.. 10).bsearch {|x| x > 0 && Math.log(x / 10) >= 0 }, 0.0001)
375
    assert_equal(nil,     (-inf...10).bsearch {|x| x > 0 && Math.log(x / 10) >= 0 }, 0.0001)
376

  
377
    assert_equal(nil, (-inf..inf).bsearch { false })
378
    assert_equal(-inf, (-inf..inf).bsearch { true })
379

  
380
    assert_equal(inf, (0..inf).bsearch {|x| x == inf })
381
    assert_equal(nil, (0...inf).bsearch {|x| x == inf })
382

  
383
    v = (-inf..0).bsearch {|x| x != -inf }
384
    assert_operator(-Float::MAX, :>=, v)
385
    assert_operator(-inf, :<, v)
386

  
387
    v = (0.0..1.0).bsearch {|x| x > 0 } # the nearest positive value to 0.0
388
    assert_in_delta(0, v, 0.0001)
389
    assert_operator(0, :<, v)
390
    assert_equal(0.0, (-1.0..0.0).bsearch {|x| x >= 0 })
391
    assert_equal(nil, (-1.0...0.0).bsearch {|x| x >= 0 })
392

  
393
    v = (0..Float::MAX).bsearch {|x| x >= Float::MAX }
394
    assert_in_delta(Float::MAX, v)
395
    assert_equal(nil, v.infinite?)
396

  
397
    v = (0..inf).bsearch {|x| x >= Float::MAX }
398
    assert_in_delta(Float::MAX, v)
399
    assert_equal(nil, v.infinite?)
400

  
401
    v = (-Float::MAX..0).bsearch {|x| x > -Float::MAX }
402
    assert_operator(-Float::MAX, :<, v)
403
    assert_equal(nil, v.infinite?)
404

  
405
    v = (-inf..0).bsearch {|x| x >= -Float::MAX }
406
    assert_in_delta(-Float::MAX, v)
407
    assert_equal(nil, v.infinite?)
408
  end
409

  
410
  def test_bsearch_for_bignum
411
    bignum = 2**100
412
    ary = [3, 4, 7, 9, 12]
413
    assert_equal(bignum + 0, (bignum...bignum+ary.size).bsearch {|i| ary[i - bignum] >= 2 })
414
    assert_equal(bignum + 1, (bignum...bignum+ary.size).bsearch {|i| ary[i - bignum] >= 4 })
415
    assert_equal(bignum + 2, (bignum...bignum+ary.size).bsearch {|i| ary[i - bignum] >= 6 })
416
    assert_equal(bignum + 3, (bignum...bignum+ary.size).bsearch {|i| ary[i - bignum] >= 8 })
417
    assert_equal(bignum + 4, (bignum...bignum+ary.size).bsearch {|i| ary[i - bignum] >= 10 })
418
    assert_equal(nil, (bignum...bignum+ary.size).bsearch {|i| ary[i - bignum] >= 100 })
419
    assert_equal(bignum + 0, (bignum...bignum+ary.size).bsearch {|i| true })
420
    assert_equal(nil, (bignum...bignum+ary.size).bsearch {|i| false })
421

  
422
    assert_raise(TypeError) { ("a".."z").bsearch {} }
423
  end
350 424
end