今天面试一位小伙子,问了旋转排序数组中的二分搜索问题(https://leetcode.cn/problems/search-in-rotated-sorted-array)。我自己先写了一份代码,如下。是不是觉得很丑陋、重复代码较多?

class Solution:
    def search(self, nums: List[int], target: int) -> int:
        """Return the index of ``target`` in a rotated sorted array, or -1.

        ``nums`` is an ascending array that has been rotated at an unknown
        pivot. Values are assumed distinct, as in the linked LeetCode problem.
        The search first bisects for the maximum element (the end of the left
        sorted run). It then runs a plain binary search over the one run that
        can contain ``target``.
        """
        n = len(nums)
        if n == 0:
            # search_highest() indexes nums[0], which would raise IndexError.
            return -1

        def search_highest():
            """Return the index of the largest element (the rotation point)."""
            l, r = 0, n - 1
            # Invariant: nums[l] >= nums[0] (left run). r is either n - 1 or
            # an index whose value is < nums[0] (right run).
            while l + 1 < r:
                mid = (l + r) // 2
                if nums[mid] > nums[0]:
                    l = mid
                else:
                    r = mid

            if nums[l] > nums[r]:
                return l
            return r

        def bin_search(l, r):
            """Binary-search the ascending slice nums[l..r]; -1 if target is absent."""
            if l > r:
                return -1

            # Shrink to at most two adjacent candidates, then test both.
            while l + 1 < r:
                mid = (l + r) // 2
                if nums[mid] < target:
                    l = mid
                else:
                    r = mid

            if nums[l] == target:
                return l
            if nums[r] == target:
                return r
            return -1

        index = search_highest()
        # nums[0..index] are all >= nums[0], and everything after index is
        # < nums[0]. Comparing against nums[0] therefore selects the only
        # run that can hold target, so at most one search is needed.
        if target >= nums[0]:
            return bin_search(0, index)
        return bin_search(index + 1, n - 1)


重构后的代码如下,看起来优雅多了。核心是用 lambda 函数把判断条件作为参数传入,让同一个二分搜索函数既能用来找旋转点,又能用来找目标值,从而消除重复代码。

class Solution:
    def search(self, nums: List[int], target: int) -> int:
        """Return the index of ``target`` in a rotated sorted array, or -1.

        ``nums`` is an ascending array that has been rotated at an unknown
        pivot. Values are assumed distinct, as in the linked LeetCode problem.
        A single generic bisection routine is reused twice, each time with a
        different predicate passed as a lambda:

        1. to locate the maximum element (the rotation point), and
        2. to look for ``target`` inside the one sorted run that can hold it.
        """
        n = len(nums)
        if n == 0:
            return -1

        def narrow(l, r, go_right):
            """Bisect [l, r] down to at most two adjacent candidates.

            While more than two indices remain, move ``l`` to ``mid`` when
            ``go_right(mid)`` is true, otherwise move ``r`` to ``mid``.
            Returns the final ``(l, r)`` pair, with ``r - l <= 1``.
            """
            while l + 1 < r:
                mid = (l + r) // 2
                if go_right(mid):
                    l = mid
                else:
                    r = mid
            return l, r

        # Values greater than nums[0] belong to the left (unrotated) run, so
        # bisecting on that predicate ends next to the maximum.
        l, r = narrow(0, n - 1, lambda mid: nums[mid] > nums[0])
        peak = l if nums[l] > nums[r] else r

        # nums[0..peak] are all >= nums[0], and everything after peak is
        # < nums[0]. Only one run can contain target.
        lo, hi = (0, peak) if target >= nums[0] else (peak + 1, n - 1)
        if lo > hi:
            # target < nums[0] but the array is not rotated: nothing to search.
            return -1

        l, r = narrow(lo, hi, lambda mid: nums[mid] < target)
        if nums[l] == target:
            return l
        if nums[r] == target:
            return r
        return -1