[Algorithms] Binary Search in Python

Created: 2013-08-30

Last modified: 2013-09-11

Wikipedia: Binary search algorithm

Time complexity: O(log n)

Precondition: the array must be sorted

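In languages with fixed-width integers, the sum low + high overflows once it exceeds the largest representable integer, for example 2 ** 31 - 1. In Java the sum wraps around to a negative number and is still negative after being divided by 2, so the lookup throws an ArrayIndexOutOfBoundsException. In C, signed overflow is undefined behavior, and in practice you may end up with an unpredictable out-of-bounds array index. The recommended fix is to change how the midpoint is computed. One option is to use subtraction instead of addition: mid = low + ((high - low) / 2). Or, if you want to show off your bit-shifting skills, use a shift: in Java, int mid = (low + high) >>> 1; in Python, mid = (low + high) >> 1. Python integers are arbitrary-precision, so the Python form cannot overflow in the first place.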
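To make the overflow concrete, here is a small sketch that imitates 32-bit signed arithmetic in Python. The helper to_int32 is a hypothetical name introduced only for illustration; since Python's own integers never overflow, the wraparound has to be simulated.

```python
def to_int32(n):
    """Wrap an arbitrary Python int to a signed 32-bit value (two's complement)."""
    return (n + 2**31) % 2**32 - 2**31

low, high = 2**30, 2**31 - 1           # both fit in a signed 32-bit int

wrapped_sum = to_int32(low + high)     # what a 32-bit `low + high` would hold
safe_mid = low + (high - low) // 2     # subtract first: no oversized intermediate
shift_mid = ((low + high) & 0xFFFFFFFF) >> 1   # imitates Java's `>>> 1`

print(wrapped_sum < 0)                 # the naive 32-bit sum has gone negative
print(safe_mid == shift_mid == (low + high) // 2)
```

Both safe forms agree with the mathematically correct midpoint, while the simulated 32-bit sum is negative.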

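I used to get confused about when to use <, >, <= or >=. After reading the Python source, that no longer bothers me. There are really only two variants, and between them they cover the different forms of the problem.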

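Docs: functions prefixed with insort insert the value; functions prefixed with bisect return an index

Source: bisect.py

Key parts:

Initialization (note that hi is exclusive: it is one past the last index of the search range).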

if lo < 0:
    raise ValueError('lo must be non-negative')
if hi is None:
    hi = len(a)
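A short sketch of what these bounds mean in practice, using the standard library's bisect_left with explicit lo and hi arguments. The variable names inside, past_end and rejected are only for illustration.

```python
from bisect import bisect_left

a = [1, 2, 2, 4, 5, 6, 7, 8, 9]

# Search only a[2:5] == [2, 4, 5]; the result is still an index into a.
inside = bisect_left(a, 2, 2, 5)

# x is larger than everything in a[0:4], so the answer is hi itself.
past_end = bisect_left(a, 9, 0, 4)

# A negative lo is rejected instead of being treated like a[-1].
try:
    bisect_left(a, 2, -1)
    rejected = False
except ValueError:
    rejected = True

print(inside, past_end, rejected)  # 2 4 True
```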

insort_left | bisect_left

while lo < hi:
    mid = (lo+hi)//2 # floor division; in languages that can overflow, use: mid = lo + ((hi-lo)>>1)
    if a[mid] < x: lo = mid+1
    else: hi = mid

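insort_right | bisect_right: note that the returned index does not point at an element equal to x. When x is present, it is the index of the rightmost element equal to x, plus 1.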

while lo < hi:
    mid = (lo+hi)//2
    if x < a[mid]: hi = mid
    else: lo = mid+1

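The only difference between the leftmost and rightmost variants is what happens when x == a[mid]: either hi = mid (the upper bound comes down, so the search drifts left) or lo = mid+1 (the lower bound goes up, so the search drifts right). Everything else is identical. mid is not returned directly because there is no way to know whether it points at the leftmost, the rightmost, or a middle occurrence.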
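Putting the initialization and the two loops together gives the following pure-Python sketch. The names my_bisect_left and my_bisect_right are hypothetical; the bodies follow the fragments quoted above and are compared against the real bisect module.

```python
from bisect import bisect_left, bisect_right

def my_bisect_left(a, x, lo=0, hi=None):
    """Leftmost insertion point for x in sorted list a."""
    if lo < 0:
        raise ValueError('lo must be non-negative')
    if hi is None:
        hi = len(a)
    while lo < hi:
        mid = (lo + hi) // 2
        if a[mid] < x:
            lo = mid + 1
        else:               # a[mid] >= x: equal elements push hi left
            hi = mid
    return lo

def my_bisect_right(a, x, lo=0, hi=None):
    """Rightmost insertion point for x in sorted list a."""
    if lo < 0:
        raise ValueError('lo must be non-negative')
    if hi is None:
        hi = len(a)
    while lo < hi:
        mid = (lo + hi) // 2
        if x < a[mid]:
            hi = mid
        else:               # a[mid] <= x: equal elements push lo right
            lo = mid + 1
    return lo

a = [1, 2, 2, 4, 5, 6, 7, 8, 9]
agree = all(
    my_bisect_left(a, x) == bisect_left(a, x)
    and my_bisect_right(a, x) == bisect_right(a, x)
    for x in range(0, 11)
)
print(agree)
```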

>>> a = [1, 2, 2, 4, 5, 6, 7, 8, 9]
>>> from bisect import bisect_left, bisect_right
>>> bisect_left(a, 0) # x is smaller than every element: returns 0
0
>>> bisect_right(a, 0)
0
>>> bisect_left(a, 2)  # index of the leftmost element equal to x
1
>>> bisect_right(a, 2)  # index of the rightmost element equal to x, plus 1
3
>>> bisect_left(a, 3)  # x is absent: returns the index of the first element greater than x
3
>>> bisect_right(a, 3)
3
>>> bisect_left(a, 10)  # x is larger than every element: returns len(a)
9
>>> bisect_right(a, 10)
9
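One direct consequence of these return values: the gap between the two insertion points is the number of occurrences of x. A minimal sketch, where count_sorted is a hypothetical helper name:

```python
from bisect import bisect_left, bisect_right

def count_sorted(a, x):
    """Count occurrences of x in sorted list a using two binary searches."""
    return bisect_right(a, x) - bisect_left(a, x)

a = [1, 2, 2, 4, 5, 6, 7, 8, 9]
print(count_sorted(a, 2), count_sorted(a, 3), count_sorted(a, 9))  # 2 0 1
```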

Analysis:

a = [1, 2, 2, 4, 5, 6, 7, 8, 9]

Suppose we want to find where x = 2 should be inserted.
Initialization:

lo = 0
hi = 9
mid = 4
a[mid] = 5

x < a[mid]
hi = mid

lo = 0
hi = 4
mid = 2
a[mid] = 2

x == a[mid]
left        | right
hi = mid    | lo = mid+1
lo = 0      | lo = 3
hi = 2      | hi = 4
mid = 1     | mid = 3
a[mid] = 2  | a[mid] = 4
x == a[mid] | x < a[mid]
hi = mid    | hi = mid

lo = 0      | lo = 3
hi = 1      | hi = 3
mid = 0     | lo == hi
a[mid] = 1  | return lo = 3
x > a[mid]  |
lo = mid+1  |

lo = 1      |
hi = 1      |
lo == hi    |
return lo = 1 |
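The hand trace above can be reproduced mechanically. The sketch below records (lo, hi, mid) at each iteration; trace_bisect and its right flag are hypothetical names introduced for this illustration, not part of the bisect module.

```python
def trace_bisect(a, x, right=False):
    """Run bisect_left (or bisect_right if right=True), recording each step."""
    lo, hi = 0, len(a)
    steps = []
    while lo < hi:
        mid = (lo + hi) // 2
        steps.append((lo, hi, mid))
        # left variant moves lo when a[mid] < x; right variant when a[mid] <= x
        go_right = (not x < a[mid]) if right else (a[mid] < x)
        if go_right:
            lo = mid + 1
        else:
            hi = mid
    return lo, steps

a = [1, 2, 2, 4, 5, 6, 7, 8, 9]
left_result, left_steps = trace_bisect(a, 2)
right_result, right_steps = trace_bisect(a, 2, right=True)
print(left_result, left_steps)
print(right_result, right_steps)
```

The two step lists share their first two entries and diverge at the point where x == a[mid], which matches the two columns of the table.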

Application: searching a sorted list

from bisect import bisect_left, bisect_right

def index(a, x):
    'Locate the leftmost value exactly equal to x'
    # Returns the index of the first occurrence of x (x may be repeated)
    i = bisect_left(a, x)
    if i != len(a) and a[i] == x:
        return i
    raise ValueError

def find_lt(a, x):
    'Find rightmost value less than x'
    # Returns the element itself, not its index
    i = bisect_left(a, x)
    if i: # i is non-zero, so something lies to the left
        return a[i-1]
    raise ValueError

def find_le(a, x):
    'Find rightmost value less than or equal to x'
    i = bisect_right(a, x)
    if i: # to require the rightmost element equal to x, use: if i and a[i-1] == x
        return a[i-1]
    raise ValueError

def find_gt(a, x):
    'Find leftmost value greater than x'
    # Returns the element itself, not its index
    i = bisect_right(a, x)
    if i != len(a): # if i == len(a), no element is greater than x
        return a[i]
    raise ValueError

def find_ge(a, x):
    'Find leftmost item greater than or equal to x'
    i = bisect_left(a, x)
    if i != len(a):
        return a[i]
    raise ValueError
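As a compact illustration of the same idea, the sketch below returns both neighbors of x at once and uses None instead of raising. The helper name neighbors and the None convention are assumptions made for this example, not something from the bisect documentation.

```python
from bisect import bisect_left, bisect_right

def neighbors(a, x):
    """Return (largest value < x, smallest value > x) in sorted list a.

    Either side is None when no such value exists.
    """
    i = bisect_left(a, x)    # everything before i is < x
    j = bisect_right(a, x)   # everything from j on is > x
    below = a[i - 1] if i else None
    above = a[j] if j != len(a) else None
    return below, above

a = [1, 2, 2, 4, 5]
print(neighbors(a, 2))  # (1, 4)
print(neighbors(a, 3))  # (2, 4)
print(neighbors(a, 0))  # (None, 1)
print(neighbors(a, 9))  # (5, None)
```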

The following example looks quite useful: numeric table lookups

>>> def grade(score, breakpoints=[60, 70, 80, 90], grades='FDCBA'):
...     i = bisect_right(breakpoints, score)
...     return grades[i]

>>> [grade(score) for score in [33, 99, 77, 70, 89, 90, 100]]
['F', 'A', 'C', 'C', 'B', 'A', 'A']

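At the time of writing, the bisect functions had no key parameter like the one sort accepts (one was added in Python 3.10). For a list that is already sorted, you can extract the search keys into a separate list whose indices line up one-to-one with the original. The search returns an index, which you then use to fetch the record from the original list.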

>>> data = [('red', 5), ('blue', 1), ('yellow', 8), ('black', 0)]
>>> data.sort(key=lambda r: r[1])
>>> keys = [r[1] for r in data]  # precomputed list of keys
>>> data[bisect_left(keys, 0)]
('black', 0)
>>> data[bisect_left(keys, 1)]
('blue', 1)
>>> data[bisect_left(keys, 5)]
('red', 5)
>>> data[bisect_left(keys, 8)]
('yellow', 8)
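On Python 3.10 and later, the same lookup can be written without the separate keys list by passing key= to bisect_left. The sketch below falls back to the precomputed-keys approach on older interpreters; lookup is a hypothetical helper name.

```python
import sys
from bisect import bisect_left

data = [('red', 5), ('blue', 1), ('yellow', 8), ('black', 0)]
data.sort(key=lambda r: r[1])

def lookup(records, k):
    """Return the first record whose second field is >= k.

    Raises IndexError if k is larger than every key.
    """
    if sys.version_info >= (3, 10):
        # key= is applied to each record; k is compared as-is
        i = bisect_left(records, k, key=lambda r: r[1])
    else:
        keys = [r[1] for r in records]   # precomputed list of keys
        i = bisect_left(keys, k)
    return records[i]

print(lookup(data, 5))  # ('red', 5)
print(lookup(data, 6))  # ('yellow', 8)
```

Note that the second call returns the next record at or above the requested key, since bisect_left only finds an insertion point; an exact-match check would compare the returned record's key with k.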

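Binary search on a rotated sorted array

Problem: suppose a sorted sequence has been shifted, that is, a trailing part of the sequence has been cut off and placed at the front. For example:

11 13 1 2 4 7 9

Given a number, how do we determine whether it is in the sequence (and return its position)?

Approach: start from what we would do in the ordinary sorted case, then handle the special cases introduced by the rotation.

On the test inputs below, duplicates resolve to the leftmost index. (Duplicates that make a[lo], a[mid] and a[hi-1] all equal, as in [1, 1, 1, 0, 1], hide the rotation point and can defeat this search.)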

def binsearch_left(a, key, lo=0, hi=None):
    if lo < 0:
        raise ValueError('lo must be non-negative')
    if hi is None:
        hi = len(a)
    while lo < hi:
        mid = (lo + hi) // 2
        if a[mid] < key:
            if a[mid] < a[lo] and a[hi-1] < key:
                hi = mid
            else:
                lo = mid+1
        else:
            if a[hi-1] < a[mid] and key < a[lo]:
                lo = mid+1
            else:
                hi = mid
    return lo if lo != len(a) and a[lo] == key else -1

arr = [7, 11, 13, 17, 2, 3, 5]
print([binsearch_left(arr, x) for x in arr])  # [0, 1, 2, 3, 4, 5, 6]
arr = [11, 13, 1, 2, 4, 7, 9]
print([binsearch_left(arr, x) for x in arr])  # [0, 1, 2, 3, 4, 5, 6]
arr = [11, 13, 15, 15, 1, 1, 1, 2, 2, 4, 7, 9]
print([binsearch_left(arr, x) for x in arr])  # [0, 1, 2, 2, 4, 4, 4, 7, 7, 9, 10, 11]

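When there are duplicates, this version returns whichever matching index it probes first, which is not guaranteed to be the rightmost (it happens to be the rightmost on the test inputs below):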

def binsearch(a, key, lo=0, hi=None):
    if lo < 0:
        raise ValueError('lo must be non-negative')
    if hi is None:
        hi = len(a)
    while lo < hi:
        mid = (lo + hi) // 2
        if key < a[mid]:
            if a[hi-1] < a[mid] and key < a[lo]:
                lo = mid+1
            else:
                hi = mid
        elif a[mid] < key:
            if a[mid] < a[lo] and a[hi-1] < key:
                hi = mid
            else:
                lo = mid+1
        else:
            return mid
    return lo if lo != len(a) and a[lo] == key else -1

arr = [7, 11, 13, 17, 2, 3, 5]
print([binsearch(arr, x) for x in arr])  # [0, 1, 2, 3, 4, 5, 6]
arr = [11, 13, 1, 2, 4, 7, 9]
print([binsearch(arr, x) for x in arr])  # [0, 1, 2, 3, 4, 5, 6]
arr = [11, 13, 15, 15, 1, 1, 1, 2, 2, 4, 7, 9]
print([binsearch(arr, x) for x in arr])  # [0, 1, 3, 3, 6, 6, 6, 8, 8, 9, 10, 11]

Same setup as the problem above; find the position of the smallest element:

def binsearch_smallest(a, lo=0, hi=None):
    if lo < 0:
        raise ValueError('lo must be non-negative')
    if hi is None:
        hi = len(a) - 1 # unlike the functions above, hi is inclusive here
    while lo < hi:
        mid = (lo + hi) // 2
        if a[hi] < a[mid]:
            lo = mid+1
        else:
            hi = mid
    return lo

arr = [7, 11, 13, 13, 2, 2, 3, 5]
print(binsearch_smallest(arr))  # 4
arr = [7, 11, 13, 17, 2, 2, 2, 3, 5]
print(binsearch_smallest(arr))  # 4
arr = [7, 11, 13, 17, 2, 2, 3, 3, 5]
print(binsearch_smallest(arr))  # 4
arr = [1, 2, 3, 3, 5, 7, 11, 13, 13]
print(binsearch_smallest(arr))  # 0
arr = [11, 13, 1, 2, 4, 7, 9]
print(binsearch_smallest(arr))  # 2
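Once the position of the smallest element is known, the rotated search can also be done in two steps: locate the rotation point, then run an ordinary bisect_left on whichever sorted half can contain the key. This is an alternative sketch, not the approach used in the functions above, and it assumes the elements are distinct. The names find_pivot and search_rotated are hypothetical.

```python
from bisect import bisect_left

def find_pivot(a):
    """Index of the smallest element in a rotated sorted list (distinct values)."""
    lo, hi = 0, len(a) - 1          # hi is inclusive here
    while lo < hi:
        mid = (lo + hi) // 2
        if a[hi] < a[mid]:          # the drop lies to the right of mid
            lo = mid + 1
        else:
            hi = mid
    return lo

def search_rotated(a, key):
    """Return the index of key in rotated sorted list a, or -1 if absent."""
    if not a:
        return -1
    p = find_pivot(a)
    # a[p:] and a[:p] are each sorted; pick the half whose range covers key
    if a[p] <= key <= a[-1]:
        lo, hi = p, len(a)
    else:
        lo, hi = 0, p
    i = bisect_left(a, key, lo, hi)
    return i if i < hi and a[i] == key else -1

arr = [11, 13, 1, 2, 4, 7, 9]
print([search_rotated(arr, x) for x in arr])  # [0, 1, 2, 3, 4, 5, 6]
print(search_rotated(arr, 5))                 # -1
```

The cost is two binary searches instead of one, but each loop stays as simple as the plain sorted case.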

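Related articles:

Binary search: do you really know it? (I skimmed it but did not follow the implementations in that post. There are too many variations of the while condition, which makes it harder to understand, not easier.)

Why doesn't the Python standard library implement a linked list?

TimeComplexity

A roundup of binary search implementations and applications

Coding binary search and its variants (there are some points in it that I still don't understand...)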