__init__.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. # SPDX-License-Identifier: AGPL-3.0-or-later
  2. import asyncio
  3. import threading
  4. import concurrent.futures
  5. from timeit import default_timer
  6. import httpx
  7. import h2.exceptions
  8. from .network import get_network, initialize
  9. from .client import get_loop
  10. from .raise_for_httperror import raise_for_httperror
  11. # queue.SimpleQueue: Support Python 3.6
  12. try:
  13. from queue import SimpleQueue
  14. except ImportError:
  15. from queue import Empty
  16. from collections import deque
  17. class SimpleQueue:
  18. """Minimal backport of queue.SimpleQueue"""
  19. def __init__(self):
  20. self._queue = deque()
  21. self._count = threading.Semaphore(0)
  22. def put(self, item):
  23. self._queue.append(item)
  24. self._count.release()
  25. def get(self):
  26. if not self._count.acquire(True):
  27. raise Empty
  28. return self._queue.popleft()
  29. THREADLOCAL = threading.local()
  30. def reset_time_for_thread():
  31. THREADLOCAL.total_time = 0
  32. def get_time_for_thread():
  33. return THREADLOCAL.total_time
  34. def set_timeout_for_thread(timeout, start_time=None):
  35. THREADLOCAL.timeout = timeout
  36. THREADLOCAL.start_time = start_time
  37. def set_context_network_name(network_name):
  38. THREADLOCAL.network = get_network(network_name)
  39. def get_context_network():
  40. try:
  41. return THREADLOCAL.network
  42. except AttributeError:
  43. return get_network()
  44. def request(method, url, **kwargs):
  45. """same as requests/requests/api.py request(...)"""
  46. time_before_request = default_timer()
  47. # timeout (httpx)
  48. if 'timeout' in kwargs:
  49. timeout = kwargs['timeout']
  50. else:
  51. timeout = getattr(THREADLOCAL, 'timeout', None)
  52. if timeout is not None:
  53. kwargs['timeout'] = timeout
  54. # 2 minutes timeout for the requests without timeout
  55. timeout = timeout or 120
  56. # ajdust actual timeout
  57. timeout += 0.2 # overhead
  58. start_time = getattr(THREADLOCAL, 'start_time', time_before_request)
  59. if start_time:
  60. timeout -= default_timer() - start_time
  61. # raise_for_error
  62. check_for_httperror = True
  63. if 'raise_for_httperror' in kwargs:
  64. check_for_httperror = kwargs['raise_for_httperror']
  65. del kwargs['raise_for_httperror']
  66. # requests compatibility
  67. if isinstance(url, bytes):
  68. url = url.decode()
  69. # network
  70. network = get_context_network()
  71. # do request
  72. future = asyncio.run_coroutine_threadsafe(network.request(method, url, **kwargs), get_loop())
  73. try:
  74. response = future.result(timeout)
  75. except concurrent.futures.TimeoutError as e:
  76. raise httpx.TimeoutException('Timeout', request=None) from e
  77. # requests compatibility
  78. # see also https://www.python-httpx.org/compatibility/#checking-for-4xx5xx-responses
  79. response.ok = not response.is_error
  80. # update total_time.
  81. # See get_time_for_thread() and reset_time_for_thread()
  82. if hasattr(THREADLOCAL, 'total_time'):
  83. time_after_request = default_timer()
  84. THREADLOCAL.total_time += time_after_request - time_before_request
  85. # raise an exception
  86. if check_for_httperror:
  87. raise_for_httperror(response)
  88. return response
  89. def get(url, **kwargs):
  90. kwargs.setdefault('allow_redirects', True)
  91. return request('get', url, **kwargs)
  92. def options(url, **kwargs):
  93. kwargs.setdefault('allow_redirects', True)
  94. return request('options', url, **kwargs)
  95. def head(url, **kwargs):
  96. kwargs.setdefault('allow_redirects', False)
  97. return request('head', url, **kwargs)
  98. def post(url, data=None, **kwargs):
  99. return request('post', url, data=data, **kwargs)
  100. def put(url, data=None, **kwargs):
  101. return request('put', url, data=data, **kwargs)
  102. def patch(url, data=None, **kwargs):
  103. return request('patch', url, data=data, **kwargs)
  104. def delete(url, **kwargs):
  105. return request('delete', url, **kwargs)
  106. async def stream_chunk_to_queue(network, q, method, url, **kwargs):
  107. try:
  108. async with network.stream(method, url, **kwargs) as response:
  109. q.put(response)
  110. async for chunk in response.aiter_bytes(65536):
  111. if len(chunk) > 0:
  112. q.put(chunk)
  113. except (httpx.HTTPError, OSError, h2.exceptions.ProtocolError) as e:
  114. q.put(e)
  115. finally:
  116. q.put(None)
  117. def stream(method, url, **kwargs):
  118. """Replace httpx.stream.
  119. Usage:
  120. stream = poolrequests.stream(...)
  121. response = next(stream)
  122. for chunk in stream:
  123. ...
  124. httpx.Client.stream requires to write the httpx.HTTPTransport version of the
  125. the httpx.AsyncHTTPTransport declared above.
  126. """
  127. q = SimpleQueue()
  128. future = asyncio.run_coroutine_threadsafe(stream_chunk_to_queue(get_network(), q, method, url, **kwargs),
  129. get_loop())
  130. chunk_or_exception = q.get()
  131. while chunk_or_exception is not None:
  132. if isinstance(chunk_or_exception, Exception):
  133. raise chunk_or_exception
  134. yield chunk_or_exception
  135. chunk_or_exception = q.get()
  136. return future.result()