__init__.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. # SPDX-License-Identifier: AGPL-3.0-or-later
  2. # lint: pylint
  3. # pylint: disable=missing-module-docstring, missing-function-docstring, global-statement
  4. import asyncio
  5. import threading
  6. import concurrent.futures
  7. from types import MethodType
  8. from timeit import default_timer
  9. import httpx
  10. import h2.exceptions
  11. from .network import get_network, initialize
  12. from .client import get_loop
  13. from .raise_for_httperror import raise_for_httperror
  14. # queue.SimpleQueue: Support Python 3.6
  15. try:
  16. from queue import SimpleQueue
  17. except ImportError:
  18. from queue import Empty
  19. from collections import deque
  20. class SimpleQueue:
  21. """Minimal backport of queue.SimpleQueue"""
  22. def __init__(self):
  23. self._queue = deque()
  24. self._count = threading.Semaphore(0)
  25. def put(self, item):
  26. self._queue.append(item)
  27. self._count.release()
  28. def get(self):
  29. if not self._count.acquire(True): #pylint: disable=consider-using-with
  30. raise Empty
  31. return self._queue.popleft()
# Module-wide threading.local() instance: an attribute set on it is only
# visible to the thread that set it.  The functions of this module store
# 'total_time', 'timeout', 'start_time' and 'network' on it.
THREADLOCAL = threading.local()
"""Thread-local data is data for thread specific values."""
  34. def reset_time_for_thread():
  35. global THREADLOCAL
  36. THREADLOCAL.total_time = 0
  37. def get_time_for_thread():
  38. """returns thread's total time or None"""
  39. global THREADLOCAL
  40. return THREADLOCAL.__dict__.get('total_time')
  41. def set_timeout_for_thread(timeout, start_time=None):
  42. global THREADLOCAL
  43. THREADLOCAL.timeout = timeout
  44. THREADLOCAL.start_time = start_time
  45. def set_context_network_name(network_name):
  46. global THREADLOCAL
  47. THREADLOCAL.network = get_network(network_name)
  48. def get_context_network():
  49. """If set return thread's network.
  50. If unset, return value from :py:obj:`get_network`.
  51. """
  52. global THREADLOCAL
  53. return THREADLOCAL.__dict__.get('network') or get_network()
  54. def request(method, url, **kwargs):
  55. """same as requests/requests/api.py request(...)"""
  56. global THREADLOCAL
  57. time_before_request = default_timer()
  58. # timeout (httpx)
  59. if 'timeout' in kwargs:
  60. timeout = kwargs['timeout']
  61. else:
  62. timeout = getattr(THREADLOCAL, 'timeout', None)
  63. if timeout is not None:
  64. kwargs['timeout'] = timeout
  65. # 2 minutes timeout for the requests without timeout
  66. timeout = timeout or 120
  67. # ajdust actual timeout
  68. timeout += 0.2 # overhead
  69. start_time = getattr(THREADLOCAL, 'start_time', time_before_request)
  70. if start_time:
  71. timeout -= default_timer() - start_time
  72. # raise_for_error
  73. check_for_httperror = True
  74. if 'raise_for_httperror' in kwargs:
  75. check_for_httperror = kwargs['raise_for_httperror']
  76. del kwargs['raise_for_httperror']
  77. # requests compatibility
  78. if isinstance(url, bytes):
  79. url = url.decode()
  80. # network
  81. network = get_context_network()
  82. # do request
  83. future = asyncio.run_coroutine_threadsafe(network.request(method, url, **kwargs), get_loop())
  84. try:
  85. response = future.result(timeout)
  86. except concurrent.futures.TimeoutError as e:
  87. raise httpx.TimeoutException('Timeout', request=None) from e
  88. # requests compatibility
  89. # see also https://www.python-httpx.org/compatibility/#checking-for-4xx5xx-responses
  90. response.ok = not response.is_error
  91. # update total_time.
  92. # See get_time_for_thread() and reset_time_for_thread()
  93. if hasattr(THREADLOCAL, 'total_time'):
  94. time_after_request = default_timer()
  95. THREADLOCAL.total_time += time_after_request - time_before_request
  96. # raise an exception
  97. if check_for_httperror:
  98. raise_for_httperror(response)
  99. return response
  100. def get(url, **kwargs):
  101. kwargs.setdefault('allow_redirects', True)
  102. return request('get', url, **kwargs)
  103. def options(url, **kwargs):
  104. kwargs.setdefault('allow_redirects', True)
  105. return request('options', url, **kwargs)
  106. def head(url, **kwargs):
  107. kwargs.setdefault('allow_redirects', False)
  108. return request('head', url, **kwargs)
  109. def post(url, data=None, **kwargs):
  110. return request('post', url, data=data, **kwargs)
  111. def put(url, data=None, **kwargs):
  112. return request('put', url, data=data, **kwargs)
  113. def patch(url, data=None, **kwargs):
  114. return request('patch', url, data=data, **kwargs)
  115. def delete(url, **kwargs):
  116. return request('delete', url, **kwargs)
  117. async def stream_chunk_to_queue(network, queue, method, url, **kwargs):
  118. try:
  119. async with network.stream(method, url, **kwargs) as response:
  120. queue.put(response)
  121. # aiter_raw: access the raw bytes on the response without applying any HTTP content decoding
  122. # https://www.python-httpx.org/quickstart/#streaming-responses
  123. async for chunk in response.aiter_raw(65536):
  124. if len(chunk) > 0:
  125. queue.put(chunk)
  126. except httpx.ResponseClosed:
  127. # the response was closed
  128. pass
  129. except (httpx.HTTPError, OSError, h2.exceptions.ProtocolError) as e:
  130. queue.put(e)
  131. finally:
  132. queue.put(None)
  133. def _close_response_method(self):
  134. asyncio.run_coroutine_threadsafe(
  135. self.aclose(),
  136. get_loop()
  137. )
  138. def stream(method, url, **kwargs):
  139. """Replace httpx.stream.
  140. Usage:
  141. stream = poolrequests.stream(...)
  142. response = next(stream)
  143. for chunk in stream:
  144. ...
  145. httpx.Client.stream requires to write the httpx.HTTPTransport version of the
  146. the httpx.AsyncHTTPTransport declared above.
  147. """
  148. queue = SimpleQueue()
  149. future = asyncio.run_coroutine_threadsafe(
  150. stream_chunk_to_queue(get_network(), queue, method, url, **kwargs),
  151. get_loop()
  152. )
  153. # yield response
  154. response = queue.get()
  155. if isinstance(response, Exception):
  156. raise response
  157. response.close = MethodType(_close_response_method, response)
  158. yield response
  159. # yield chunks
  160. chunk_or_exception = queue.get()
  161. while chunk_or_exception is not None:
  162. if isinstance(chunk_or_exception, Exception):
  163. raise chunk_or_exception
  164. yield chunk_or_exception
  165. chunk_or_exception = queue.get()
  166. future.result()