# searx/network/__init__.py
  1. # SPDX-License-Identifier: AGPL-3.0-or-later
  2. # lint: pylint
  3. # pylint: disable=missing-module-docstring, global-statement
  4. import asyncio
  5. import threading
  6. import concurrent.futures
  7. from queue import SimpleQueue
  8. from types import MethodType
  9. from timeit import default_timer
  10. from typing import Iterable, Tuple
  11. import httpx
  12. import anyio
  13. from .network import get_network, initialize, check_network_configuration
  14. from .client import get_loop
  15. from .raise_for_httperror import raise_for_httperror
THREADLOCAL = threading.local()
"""Thread-local storage for per-thread request state: this module stores
``timeout``, ``start_time``, ``total_time`` and ``network`` on it."""
  18. def reset_time_for_thread():
  19. THREADLOCAL.total_time = 0
  20. def get_time_for_thread():
  21. """returns thread's total time or None"""
  22. return THREADLOCAL.__dict__.get('total_time')
  23. def set_timeout_for_thread(timeout, start_time=None):
  24. THREADLOCAL.timeout = timeout
  25. THREADLOCAL.start_time = start_time
  26. def set_context_network_name(network_name):
  27. THREADLOCAL.network = get_network(network_name)
  28. def get_context_network():
  29. """If set return thread's network.
  30. If unset, return value from :py:obj:`get_network`.
  31. """
  32. return THREADLOCAL.__dict__.get('network') or get_network()
  33. def request(method, url, **kwargs):
  34. """same as requests/requests/api.py request(...)"""
  35. time_before_request = default_timer()
  36. # timeout (httpx)
  37. if 'timeout' in kwargs:
  38. timeout = kwargs['timeout']
  39. else:
  40. timeout = getattr(THREADLOCAL, 'timeout', None)
  41. if timeout is not None:
  42. kwargs['timeout'] = timeout
  43. # 2 minutes timeout for the requests without timeout
  44. timeout = timeout or 120
  45. # ajdust actual timeout
  46. timeout += 0.2 # overhead
  47. start_time = getattr(THREADLOCAL, 'start_time', time_before_request)
  48. if start_time:
  49. timeout -= default_timer() - start_time
  50. # raise_for_error
  51. check_for_httperror = True
  52. if 'raise_for_httperror' in kwargs:
  53. check_for_httperror = kwargs['raise_for_httperror']
  54. del kwargs['raise_for_httperror']
  55. # requests compatibility
  56. if isinstance(url, bytes):
  57. url = url.decode()
  58. # network
  59. network = get_context_network()
  60. # do request
  61. future = asyncio.run_coroutine_threadsafe(network.request(method, url, **kwargs), get_loop())
  62. try:
  63. response = future.result(timeout)
  64. except concurrent.futures.TimeoutError as e:
  65. raise httpx.TimeoutException('Timeout', request=None) from e
  66. # requests compatibility
  67. # see also https://www.python-httpx.org/compatibility/#checking-for-4xx5xx-responses
  68. response.ok = not response.is_error
  69. # update total_time.
  70. # See get_time_for_thread() and reset_time_for_thread()
  71. if hasattr(THREADLOCAL, 'total_time'):
  72. time_after_request = default_timer()
  73. THREADLOCAL.total_time += time_after_request - time_before_request
  74. # raise an exception
  75. if check_for_httperror:
  76. raise_for_httperror(response)
  77. return response
  78. def get(url, **kwargs):
  79. kwargs.setdefault('allow_redirects', True)
  80. return request('get', url, **kwargs)
  81. def options(url, **kwargs):
  82. kwargs.setdefault('allow_redirects', True)
  83. return request('options', url, **kwargs)
  84. def head(url, **kwargs):
  85. kwargs.setdefault('allow_redirects', False)
  86. return request('head', url, **kwargs)
  87. def post(url, data=None, **kwargs):
  88. return request('post', url, data=data, **kwargs)
  89. def put(url, data=None, **kwargs):
  90. return request('put', url, data=data, **kwargs)
  91. def patch(url, data=None, **kwargs):
  92. return request('patch', url, data=data, **kwargs)
  93. def delete(url, **kwargs):
  94. return request('delete', url, **kwargs)
async def stream_chunk_to_queue(network, queue, method, url, **kwargs):
    """Stream the response for *method* / *url* into *queue*.

    Runs on the event loop; the consumer (:py:obj:`_stream_generator`) reads
    the queue from its own thread.  Queue protocol: first the httpx response
    object, then raw byte chunks, then (on failure) the exception, and in
    every case a final ``None`` sentinel so the consumer can stop.
    """
    try:
        async with await network.stream(method, url, **kwargs) as response:
            queue.put(response)
            # aiter_raw: access the raw bytes on the response without applying any HTTP content decoding
            # https://www.python-httpx.org/quickstart/#streaming-responses
            async for chunk in response.aiter_raw(65536):
                if len(chunk) > 0:
                    queue.put(chunk)
    except (httpx.StreamClosed, anyio.ClosedResourceError):
        # the response was queued before the exception.
        # the exception was raised on aiter_raw.
        # we do nothing here: in the finally block, None will be queued
        # so stream(method, url, **kwargs) generator can stop
        pass
    except Exception as e:  # pylint: disable=broad-except
        # broad except to avoid this scenario:
        # exception in network.stream(method, url, **kwargs)
        # -> the exception is not caught here
        # -> queue None (in finally)
        # -> the function below stream(method, url, **kwargs) has nothing to return
        queue.put(e)
    finally:
        queue.put(None)
  119. def _stream_generator(method, url, **kwargs):
  120. queue = SimpleQueue()
  121. network = get_context_network()
  122. future = asyncio.run_coroutine_threadsafe(stream_chunk_to_queue(network, queue, method, url, **kwargs), get_loop())
  123. # yield chunks
  124. obj_or_exception = queue.get()
  125. while obj_or_exception is not None:
  126. if isinstance(obj_or_exception, Exception):
  127. raise obj_or_exception
  128. yield obj_or_exception
  129. obj_or_exception = queue.get()
  130. future.result()
  131. def _close_response_method(self):
  132. asyncio.run_coroutine_threadsafe(self.aclose(), get_loop())
  133. # reach the end of _self.generator ( _stream_generator ) to an avoid memory leak.
  134. # it makes sure that :
  135. # * the httpx response is closed (see the stream_chunk_to_queue function)
  136. # * to call future.result() in _stream_generator
  137. for _ in self._generator: # pylint: disable=protected-access
  138. continue
  139. def stream(method, url, **kwargs) -> Tuple[httpx.Response, Iterable[bytes]]:
  140. """Replace httpx.stream.
  141. Usage:
  142. response, stream = poolrequests.stream(...)
  143. for chunk in stream:
  144. ...
  145. httpx.Client.stream requires to write the httpx.HTTPTransport version of the
  146. the httpx.AsyncHTTPTransport declared above.
  147. """
  148. generator = _stream_generator(method, url, **kwargs)
  149. # yield response
  150. response = next(generator) # pylint: disable=stop-iteration-return
  151. if isinstance(response, Exception):
  152. raise response
  153. response._generator = generator # pylint: disable=protected-access
  154. response.close = MethodType(_close_response_method, response)
  155. return response, generator