# Memoized Search with Python's @lru_cache Decorator

## What decorators do

A decorator is not a comment. A decorator modifies a method or attribute, attaching extra behavior to it, and therefore affects the code itself; a comment is merely explanatory text added for the reader's benefit. The two are fundamentally different. The following minimal example shows what a decorator buys us.

```python
class Animal:

    cnt = 0

    def __init__(self, name: str):
        self.name = name
        Animal.cnt += 1

    def count() -> int:
        return Animal.cnt
```

As written, count() is an ordinary function in the class body. Calling it through the class, Animal.count(), happens to work in Python 3 because the function is looked up on the class and invoked with no arguments, but calling it through an instance, a.count(), raises a TypeError: the interpreter passes the instance as an implicit first argument that count() does not accept. Either way, what we actually want is for count to be a static method of the class.
We therefore use the @staticmethod decorator to declare count a static method. Note, by the way, that attributes assigned directly in the class body (such as cnt above) are class attributes shared by all instances, whereas per-instance attributes are created through self, usually in the constructor __init__.
```python
class Animal:

    cnt = 0

    def __init__(self, name: str):
        self.name = name
        Animal.cnt += 1

    @staticmethod
    def count() -> int:
        return Animal.cnt


a = Animal("cat")
b = Animal("dog")
Animal.count()
Out[5]: 2
```
The @lru_cache decorator caches the results of the function it wraps. LRU (Least Recently Used) is, strictly speaking, an eviction policy; here it merely names the caching strategy, and the essential point is that results are memoized. In practice the decorator shines on recursive problems, where caching the intermediate results avoids recomputation and dramatically improves efficiency.
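As a quick sanity check, here is a minimal sketch (the function square is illustrative, not from the original) showing the cache at work through cache_info(), which reports hits and misses:

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def square(x: int) -> int:
    return x * x

square(3)                    # first call: computed, one miss
square(3)                    # second call: served from the cache, one hit
print(square.cache_info())   # CacheInfo(hits=1, misses=1, maxsize=None, currsize=1)
```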

## Efficiency analysis

The efficiency gain from memoized search is enormous: a plain recursion spawns huge numbers of duplicate calls that recompute values already obtained, wasting a great deal of time. Anyone who has spent years writing algorithms in C++ and similar languages is already used to memoized search in its various forms. Ordinarily, memoization requires the programmer to allocate the storage and define the states by hand, which is tedious; it would be ideal if already-computed results were recorded automatically, and the subject of this article, Python's @lru_cache decorator, does exactly that.

```cpp
// Traditional memoized search in C++,
// adapted from the memoization section of OI Wiki.
int dfs(int i, int j, int k) {
  // boundary checks go here
  if (mem[i][j][k] != -1) return mem[i][j][k];
  return mem[i][j][k] = dfs(i + 1, j + 1, k - a[j]) + dfs(i + 1, j, k);
}
int main() {
  memset(mem, -1, sizeof(mem));
  // input reading omitted
  cout << dfs(1, 0, 0) << endl;
}
```

```cpp
// Template
int g[MAXN];
int f(/* state parameters */) {
  if (g[/* size */] != /* sentinel */) return g[/* size */];
  if (/* base case */) return /* base-case answer */;
  g[/* size */] = f(/* reduced size */);
  return g[/* size */];
}
int main() {
  // ...
  memset(g, /* sentinel */, sizeof(g));
  // ...
}
```
We now take the classic task of computing Fibonacci numbers and compare the two versions. It is not hard to see that before the decorator the time complexity is [latex]O(2^n)[/latex], while with the decorator memoizing the recursion it drops to [latex]O(n)[/latex] time and [latex]O(n)[/latex] space. Incidentally, for linear recurrences like this one, matrix fast exponentiation does better still: [latex]O(\log n)[/latex] time in [latex]O(1)[/latex] space.

### Before the decorator
```python
# jupyter
from functools import *
import time
def fibonacci(n: int) -> int:
    if n <= 1:
        return 1
    else:
        return fibonacci(n - 2) + fibonacci(n - 1)
def dfs(n: int) -> float:
    t0 = time.perf_counter()
    fibonacci(n)
    t1 = time.perf_counter()
    return t1 - t0


dfs(30)
Out[4]: 0.2971243000000001
dfs(35)
Out[5]: 3.7380671000000003
dfs(40)
Out[6]: 40.5477137
```
### After the decorator
```python
# jupyter
from functools import *
import time
@lru_cache(None)
def fibonacci(n: int) -> int:
    if n <= 1:
        return 1
    else:
        return fibonacci(n - 2) + fibonacci(n - 1)
def dfs(n: int) -> float:
    t0 = time.perf_counter()
    fibonacci(n)
    t1 = time.perf_counter()
    return t1 - t0


dfs(40)
Out[4]: 1.6099999999852344e-05
dfs(100)
Out[5]: 2.459999999970819e-05
dfs(500)
Out[6]: 0.00028539999999921406
```
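As an aside on the matrix fast exponentiation mentioned above, here is a minimal sketch for Fibonacci (not part of the original benchmark; fibonacci_fast follows this article's convention fibonacci(0) = fibonacci(1) = 1). It relies on the fact that the k-th power of [[1, 1], [1, 0]] carries the k-th Fibonacci number, and squaring halves the exponent at each step:

```python
def mat_mul(a, b):
    # product of two 2x2 matrices
    return [[a[0][0] * b[0][0] + a[0][1] * b[1][0], a[0][0] * b[0][1] + a[0][1] * b[1][1]],
            [a[1][0] * b[0][0] + a[1][1] * b[1][0], a[1][0] * b[0][1] + a[1][1] * b[1][1]]]

def mat_pow(m, e):
    # exponentiation by squaring: O(log e) matrix multiplications
    result = [[1, 0], [0, 1]]   # 2x2 identity
    while e:
        if e & 1:
            result = mat_mul(result, m)
        m = mat_mul(m, m)
        e >>= 1
    return result

def fibonacci_fast(n: int) -> int:
    # [[1,1],[1,0]]**k has F(k) in its top-right corner (F(1) = F(2) = 1);
    # this article's fibonacci(n) equals F(n + 1)
    return mat_pow([[1, 1], [1, 0]], n + 1)[0][1]
```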

## How to use @lru_cache

lru_cache(maxsize=128, typed=False) takes two parameters. maxsize is the maximum number of results kept in the cache, 128 by default; passing None removes the bound entirely (although Python still limits how deep the recursion may go). If typed=True, calls whose arguments differ in type are cached separately, so f(3) and f(3.0) become distinct cache entries. To apply it, simply put @lru_cache(None) above the recursive function, together with from functools import *; alternatively, with a plain import functools, write @functools.lru_cache(None), as in the example below.
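A quick demonstration of both parameters, as a minimal sketch (f is a throwaway function, not from the original):

```python
from functools import lru_cache

@lru_cache(maxsize=128, typed=True)
def f(x):
    return x + 1

f(3)                    # int argument: one miss
f(3.0)                  # float argument: cached separately because typed=True
f(3)                    # repeated int argument: one hit
print(f.cache_info())   # CacheInfo(hits=1, misses=2, maxsize=128, currsize=2)
```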

The problem is T3 from LeetCode Weekly Contest 276:

> You are given a 0-indexed 2D integer array questions where questions[i] = [points_i, brainpower_i].
> The array describes the questions of an exam, which you must process in order (starting from question 0), deciding for each one whether to solve it or skip it. Solving question i earns you points_i points but makes you unable to solve the next brainpower_i questions (you must skip them). If you skip question i, you get to choose the action for the next question.
> For example, given questions = [[3, 2], [4, 3], [4, 4], [2, 5]]:
> If question 0 is solved, you earn 3 points but cannot solve questions 1 and 2.
> If you skip question 0 and solve question 1, you earn 4 points but cannot solve questions 2 and 3.
> Return the maximum points you can earn in the exam.
> Constraints:
> 1 <= questions.length <= 10^5
> questions[i].length == 2
> 1 <= points_i, brainpower_i <= 10^5
> Source: LeetCode
> Link: https://leetcode-cn.com/problems/solving-questions-with-brainpower
> All rights reserved by LeetCode (领扣网络). Contact them for authorization before commercial reprinting; credit the source for non-commercial reprints.

```python
import functools
from typing import List


class Solution:
    def mostPoints(self, questions: List[List[int]]) -> int:

        @functools.lru_cache(None)
        def solve(t=0):
            # best score obtainable from question t onward
            if t >= len(questions):
                return 0
            points, brainpower = questions[t]
            # either solve question t and jump past the blocked ones, or skip it
            return max(points + solve(t + brainpower + 1), solve(t + 1))

        return solve()
```
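One caveat: with questions.length as large as 10^5, the recursion above can exceed CPython's default recursion limit (1000), so sys.setrecursionlimit or an iterative rewrite may be needed. A minimal bottom-up sketch of the same recurrence (the name most_points_iterative is illustrative, not from the original):

```python
from typing import List

def most_points_iterative(questions: List[List[int]]) -> int:
    n = len(questions)
    dp = [0] * (n + 1)          # dp[t] = best score from question t onward
    for t in range(n - 1, -1, -1):
        points, brainpower = questions[t]
        take = points + dp[min(t + brainpower + 1, n)]   # solve question t
        skip = dp[t + 1]                                 # skip question t
        dp[t] = max(take, skip)
    return dp[0]
```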

## Appendix: the @lru_cache source code

```python
def lru_cache(maxsize=128, typed=False):
    """Least-recently-used cache decorator.

    If *maxsize* is set to None, the LRU features are disabled and the cache
    can grow without bound.

    If *typed* is True, arguments of different types will be cached separately.
    For example, f(3.0) and f(3) will be treated as distinct calls with
    distinct results.

    Arguments to the cached function must be hashable.

    View the cache statistics named tuple (hits, misses, maxsize, currsize)
    with f.cache_info().  Clear the cache and statistics with f.cache_clear().
    Access the underlying function with f.__wrapped__.

    See:  http://en.wikipedia.org/wiki/Cache_algorithms#Least_Recently_Used

    """

    # Users should only access the lru_cache through its public API:
    #       cache_info, cache_clear, and f.__wrapped__
    # The internals of the lru_cache are encapsulated for thread safety and
    # to allow the implementation to change (including a possible C version).

    # Early detection of an erroneous call to @lru_cache without any arguments
    # resulting in the inner function being passed to maxsize instead of an
    # integer or None.
    if maxsize is not None and not isinstance(maxsize, int):
        raise TypeError('Expected maxsize to be an integer or None')

    def decorating_function(user_function):
        wrapper = _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo)
        return update_wrapper(wrapper, user_function)

    return decorating_function

def _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo):
    # Constants shared by all lru cache instances:
    sentinel = object()          # unique object used to signal cache misses
    make_key = _make_key         # build a key from the function arguments
    PREV, NEXT, KEY, RESULT = 0, 1, 2, 3   # names for the link fields

    cache = {}
    hits = misses = 0
    full = False
    cache_get = cache.get    # bound method to lookup a key or return None
    cache_len = cache.__len__  # get cache size without calling len()
    lock = RLock()           # because linkedlist updates aren't threadsafe
    root = []                # root of the circular doubly linked list
    root[:] = [root, root, None, None]     # initialize by pointing to self

    if maxsize == 0:

        def wrapper(*args, **kwds):
            # No caching -- just a statistics update after a successful call
            nonlocal misses
            result = user_function(*args, **kwds)
            misses += 1
            return result

    elif maxsize is None:

        def wrapper(*args, **kwds):
            # Simple caching without ordering or size limit
            nonlocal hits, misses
            key = make_key(args, kwds, typed)
            result = cache_get(key, sentinel)
            if result is not sentinel:
                hits += 1
                return result
            result = user_function(*args, **kwds)
            cache[key] = result
            misses += 1
            return result

    else:

        def wrapper(*args, **kwds):
            # Size limited caching that tracks accesses by recency
            nonlocal root, hits, misses, full
            key = make_key(args, kwds, typed)
            with lock:
                link = cache_get(key)
                if link is not None:
                    # Move the link to the front of the circular queue
                    link_prev, link_next, _key, result = link
                    link_prev[NEXT] = link_next
                    link_next[PREV] = link_prev
                    last = root[PREV]
                    last[NEXT] = root[PREV] = link
                    link[PREV] = last
                    link[NEXT] = root
                    hits += 1
                    return result
            result = user_function(*args, **kwds)
            with lock:
                if key in cache:
                    # Getting here means that this same key was added to the
                    # cache while the lock was released.  Since the link
                    # update is already done, we need only return the
                    # computed result and update the count of misses.
                    pass
                elif full:
                    # Use the old root to store the new key and result.
                    oldroot = root
                    oldroot[KEY] = key
                    oldroot[RESULT] = result
                    # Empty the oldest link and make it the new root.
                    # Keep a reference to the old key and old result to
                    # prevent their ref counts from going to zero during the
                    # update. That will prevent potentially arbitrary object
                    # clean-up code (i.e. __del__) from running while we're
                    # still adjusting the links.
                    root = oldroot[NEXT]
                    oldkey = root[KEY]
                    oldresult = root[RESULT]
                    root[KEY] = root[RESULT] = None
                    # Now update the cache dictionary.
                    del cache[oldkey]
                    # Save the potentially reentrant cache[key] assignment
                    # for last, after the root and links have been put in
                    # a consistent state.
                    cache[key] = oldroot
                else:
                    # Put result in a new link at the front of the queue.
                    last = root[PREV]
                    link = [last, root, key, result]
                    last[NEXT] = root[PREV] = cache[key] = link
                    # Use the cache_len bound method instead of the len() function
                    # which could potentially be wrapped in an lru_cache itself.
                    full = (cache_len() >= maxsize)
                misses += 1
            return result

    def cache_info():
        """Report cache statistics"""
        with lock:
            return _CacheInfo(hits, misses, maxsize, cache_len())

    def cache_clear():
        """Clear the cache and cache statistics"""
        nonlocal hits, misses, full
        with lock:
            cache.clear()
            root[:] = [root, root, None, None]
            hits = misses = 0
            full = False

    wrapper.cache_info = cache_info
    wrapper.cache_clear = cache_clear
    return wrapper

try:
    from _functools import _lru_cache_wrapper
except ImportError:
    pass
```
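The public API named in the docstring is easy to exercise; a minimal sketch (double is a throwaway function):

```python
from functools import lru_cache

@lru_cache(maxsize=2)
def double(x):
    return 2 * x

double(1); double(2); double(3)   # maxsize=2, so the entry for 1 is evicted
print(double.cache_info())        # CacheInfo(hits=0, misses=3, maxsize=2, currsize=2)
double.cache_clear()              # drop all entries and reset the statistics
print(double.cache_info())        # CacheInfo(hits=0, misses=0, maxsize=2, currsize=0)
print(double.__wrapped__(5))      # bypass the cache: calls the raw function, prints 10
```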


Author: Commander
License: Unless otherwise stated, all posts on this blog are licensed under CC BY 4.0. Please credit Commander when reprinting!