Flask之上下文管理
知识储备之问题情境:
request中的参数:
- 单进程单线程
- 单进程多线程-->reqeust 会因为多个请求,数据发生错乱.--->可以基于threading.local对象
- 单进程单线程(多协程)threading.local对象做不到(因为一个线程下多个协程同享一个线程的资源)
解决办法:
自定义类似threading.local对象(支持协程)---保证多协程下数据的安全
先来看一下下面这段代码(支持多线程):
# -*- coding: utf-8 -*- """ 1288::{} """ from _thread import get_ident import threading class Local(object): def __init__(self): self.storage = {} self.get_ident = get_ident # 设置值 def set(self, k, v): # 获取线程的唯一标识 ident = self.get_ident() # 通过唯一标识去字典里面取值 origin = self.storage.get(ident) if not origin: origin = {k: v} else: origin[k] = v # 将k,v 保存到 storage中 形式如下 # { # 1023:{k,v}, # self.storage[ident] = origin 所添加的值 # 1045:{k1,v1} # 原先storage中有的值 # } self.storage[ident] = origin # 获取值 def get(self, k): ident = self.get_ident() origin = self.storage.get(ident) if not origin: return None return origin.get(k, None) # 获取一个线程对象 local_obj = Local() # 获取每一个线程的唯一标识 def task(num): local_obj.set('name',num) import time time.sleep(1) print(local_obj.get('name'),threading.current_thread().name) for i in range(20): th = threading.Thread(target=task, args=(i,), name='线程%s' % i) th.start() """ 0 线程0 1 线程1 2 线程2 5 线程5 6 线程6 3 线程3 4 线程4 10 线程10 9 线程9 11 线程11 7 线程7 13 线程13 14 线程14 17 线程17 18 线程18 15 线程15 19 线程19 8 线程8 12 线程12 16 线程16 """
再进一步,支持协程
# 首先需要安装依赖 pip3 intall gevent # gevent 依赖安装 greenlet 可以获取协程的唯一标识 # -*- coding: utf-8 -*- """ 1288::{ } """ try: # 优先用协程的 # 如果是单线程多协程,导入获取协程唯一标识的 from greenlet import getcurrent as get_ident # 协程 except ImportError: try: # 如果是多线程导入获取线程唯一标识的 from thread import get_ident except ImportError: # 如果是多线程导入获取线程唯一标识的 from _thread import get_ident # 线程 class Local(object): def __init__(self): self.storage = {} self.get_ident = get_ident # 设置值 def set(self, k, v): # 获取线程的唯一标识 ident = self.get_ident() # 通过唯一标识去字典里面取值 origin = self.storage.get(ident) if not origin: origin = {k: v} else: origin[k] = v # 将k,v 保存到 storage中 形式如下 # { # 1023:{k,v}, # self.storage[ident] = origin 所添加的值 # 1045:{k1,v1} # 原先storage中有的值 # } self.storage[ident] = origin # 获取值 def get(self, k): ident = self.get_ident() origin = self.storage.get(ident) if not origin: return None return origin.get(k, None) # 获取一个线程对象 local_obj = Local() # 获取每一个线程的唯一标识 def task(num): local_obj.set('name', num) import time time.sleep(1) print(local_obj.get('name'), threading.current_thread().name) for i in range(20): th = threading.Thread(target=task, args=(i,), name='线/协程%s' % i) th.start()
flask中实现的方式
flask中运用了面向对象的一些方法重试简化了实现方式
先补充了解面向对象的姿势:
class Foo(): # 在执行 对象.属性 = 值的时候执行,这里可以写赋值操作 def __setattr__(self,key,value): print(key,value) # 在执行 对象.属性的时候,执行, 这里可以写获取对象的属性 def __getattr__(self, item): print(item) foo = Foo() foo.x = 123 foo.x
但是还是有点问题 上面写法: 如果在 初始化操作的时候,会出现递归问题
class Foo(): def __init__(self): self.storage ={} def __setattr__(self,key,value): self.storage = {'k':'v'} print(key,value) def __getattr__(self, item): print(item) foo = Foo() foo.x = 123 foo.x """ 上述办法 会在 __setattr__ 这里产生递归 self.storage = {'k':'v'} [Previous line repeated 327 more times] RecursionError: maximum recursion depth exceeded """
解决办法
class Foo(object): def __init__(self): object.__setattr__(self, "storage", {}) # self.storage = {} def __setattr__(self, key, value): storage = self.storage storage['1024'] = {key: value} print(storage) def __getattr__(self, item): print(item) """ {'1024': {'x': 123}} x """
上述问题 接近源码的做法实现一个支持协程线程的自定义类似threading.local 对象
# -*- coding: utf-8 -*- """ 模仿 flask中运用了一些面向对象的方法: __getattr__,__setattr__ """ import threading try: # 优先用协程的 # 如果是单线程多协程,导入获取协程唯一标识的 from greenlet import getcurrent as get_ident # 协程 except ImportError: try: # 如果是多线程导入获取线程唯一标识的 from thread import get_ident except ImportError: # 如果是多线程导入获取线程唯一标识的 from _thread import get_ident # 线程 class Local(object): def __init__(self): object.__setattr__(self, "__storage__", {}) object.__setattr__(self, "__ident_func__", get_ident) def __getattr__(self, name): try: return self.__storage__[self.__ident_func__()][name] except KeyError: raise AttributeError(name) def __setattr__(self, name, value): ident = self.__ident_func__() storage = self.__storage__ try: storage[ident][name] = value except KeyError: storage[ident] = {name: value} # 获取一个线程对象 local_obj = Local() # 获取每一个线程的唯一标识 def task(num): local_obj.name = num import time time.sleep(1) print(local_obj.name, threading.current_thread().name) for i in range(20): th = threading.Thread(target=task, args=(i,), name='线程%s' % i) th.start() """ 0 线程0 3 线程3 4 线程4 1 线程1 2 线程2 8 线程8 7 线程7 5 线程5 6 线程6 10 线程10 9 线程9 11 线程11 12 线程12 15 线程15 14 线程14 13 线程13 19 线程19 16 线程16 18 线程18 17 线程17 """
flask 源码实现方式
try: from greenlet import getcurrent as get_ident except ImportError: try: from thread import get_ident except ImportError: from _thread import get_ident class Local(object): def __init__(self): """当类 实例化产生函数的时候初始化的时候被调用""" object.__setattr__(self, "__storage__", {}) object.__setattr__(self, "__ident_func__", get_ident) def __call__(self, proxy): """ 当类实例化的对象 被 调用的时候执行该函数 """ """Create a proxy for a name.""" return LocalProxy(self, proxy) def __release_local__(self): self.__storage__.pop(self.__ident_func__(), None) def __getattr__(self, name): """定义当用户试图获取一个不存在的属性时的行为""" try: return self.__storage__[self.__ident_func__()][name] except KeyError: raise AttributeError(name) def __setattr__(self, name, value): """定义当一个属性被设置时的行为""" ident = self.__ident_func__() storage = self.__storage__ try: storage[ident][name] = value except KeyError: storage[ident] = {name: value} def __delattr__(self, name): """定义当一个属性被删除时的行为""" try: del self.__storage__[self.__ident_func__()][name] except KeyError: raise AttributeError(name)
PS: flask 中保存请求相关 session相关的对象的在并发的时候的不同(保证数据的安全),都是基于这个 threading.local 实现的