# SPDX-License-Identifier: AGPL-3.0-or-later
# lint: pylint
# pylint: disable=missing-module-docstring, global-statement

import asyncio
import threading
import concurrent.futures
from queue import SimpleQueue
from types import MethodType
from timeit import default_timer
from typing import Iterable, NamedTuple, Tuple, List, Dict, Union
from contextlib import contextmanager

import httpx
import anyio

from .network import get_network, initialize, check_network_configuration  # pylint:disable=cyclic-import
from .client import get_loop
from .raise_for_httperror import raise_for_httperror


THREADLOCAL = threading.local()
"""Thread-local data is data for thread specific values."""


def reset_time_for_thread():
    """Reset the HTTP time accumulated by the current thread to zero."""
    THREADLOCAL.total_time = 0


def get_time_for_thread():
    """Return the current thread's total HTTP time, or None if it was never reset."""
    return THREADLOCAL.__dict__.get('total_time')


def set_timeout_for_thread(timeout, start_time=None):
    """Store the default timeout and the reference start time for the current thread.

    Both values are read back by :py:obj:`_get_timeout` and
    :py:obj:`_record_http_time`.
    """
    THREADLOCAL.timeout = timeout
    THREADLOCAL.start_time = start_time


def set_context_network_name(network_name):
    """Bind the network returned by ``get_network(network_name)`` to the current thread."""
    THREADLOCAL.network = get_network(network_name)


def get_context_network():
    """If set return thread's network.

    If unset, return value from :py:obj:`get_network`.
    """
    return THREADLOCAL.__dict__.get('network') or get_network()


@contextmanager
def _record_http_time():
    """Context manager measuring the time spent in the ``with`` body.

    Yields the thread's ``start_time`` if one is set (see
    :py:obj:`set_timeout_for_thread`), otherwise the time at which the context
    was entered.  On exit, the elapsed time is added to the thread's
    ``total_time`` when that attribute exists.
    """
    # pylint: disable=too-many-branches
    time_before_request = default_timer()
    start_time = getattr(THREADLOCAL, 'start_time', time_before_request)
    try:
        yield start_time
    finally:
        # update total_time.
        # See get_time_for_thread() and reset_time_for_thread()
        if hasattr(THREADLOCAL, 'total_time'):
            time_after_request = default_timer()
            THREADLOCAL.total_time += time_after_request - time_before_request


def _get_timeout(start_time, kwargs):
    """Return the number of seconds to wait for the future of a request.

    ``kwargs`` are the keyword arguments of the request.  When they carry no
    ``'timeout'`` entry and the thread has a default timeout, that default is
    written into ``kwargs`` (the dict is modified in place), so callers must
    not pass a dict they share with someone else.

    The returned value is the timeout (120 seconds when none is defined) plus
    a small overhead, minus the time already elapsed since ``start_time`` when
    ``start_time`` is truthy.
    """
    # pylint: disable=too-many-branches

    # timeout (httpx)
    if 'timeout' in kwargs:
        timeout = kwargs['timeout']
    else:
        timeout = getattr(THREADLOCAL, 'timeout', None)
        if timeout is not None:
            kwargs['timeout'] = timeout

    # 2 minutes timeout for the requests without timeout
    timeout = timeout or 120

    # adjust actual timeout
    timeout += 0.2  # overhead
    if start_time:
        timeout -= default_timer() - start_time

    return timeout


def request(method, url, **kwargs):
    """same as requests/requests/api.py request(...)

    The request is scheduled on the loop returned by ``get_loop()`` using the
    thread's network, and this function blocks until the result is available.

    Raises ``httpx.TimeoutException`` when the result is not available within
    the timeout computed by :py:obj:`_get_timeout`.
    """
    with _record_http_time() as start_time:
        network = get_context_network()
        timeout = _get_timeout(start_time, kwargs)
        future = asyncio.run_coroutine_threadsafe(network.request(method, url, **kwargs), get_loop())
        try:
            return future.result(timeout)
        except concurrent.futures.TimeoutError as e:
            raise httpx.TimeoutException('Timeout', request=None) from e


def multi_requests(request_list: List["Request"]) -> List[Union[httpx.Response, Exception]]:
    """send multiple HTTP requests in parallel. Wait for all requests to finish.

    The returned list has one entry per item of ``request_list``, in the same
    order: either the response, or the exception describing why the request
    failed (a ``httpx.TimeoutException`` when the result did not arrive in time).
    """
    with _record_http_time() as start_time:
        # send the requests
        network = get_context_network()
        loop = get_loop()
        future_list = []
        for request_desc in request_list:
            # Work on a copy: _get_timeout() may add a 'timeout' entry, and the
            # kwargs dict belongs to the caller (or is the default dict shared
            # by every Request created without kwargs).
            kwargs = dict(request_desc.kwargs)
            timeout = _get_timeout(start_time, kwargs)
            future = asyncio.run_coroutine_threadsafe(
                network.request(request_desc.method, request_desc.url, **kwargs), loop
            )
            future_list.append((future, timeout))

        # read the responses
        responses = []
        for future, timeout in future_list:
            try:
                responses.append(future.result(timeout))
            except concurrent.futures.TimeoutError:
                responses.append(httpx.TimeoutException('Timeout', request=None))
            except Exception as e:  # pylint: disable=broad-except
                responses.append(e)
        return responses


class Request(NamedTuple):
    """Request description for the multi_requests function"""

    method: str
    url: str
    # NOTE: this default dict is a single object shared by every Request
    # created without kwargs; it must never be modified in place.
    kwargs: Dict[str, str] = {}

    @staticmethod
    def get(url, **kwargs):
        """Describe a GET request."""
        return Request('GET', url, kwargs)

    @staticmethod
    def options(url, **kwargs):
        """Describe an OPTIONS request."""
        return Request('OPTIONS', url, kwargs)

    @staticmethod
    def head(url, **kwargs):
        """Describe a HEAD request."""
        return Request('HEAD', url, kwargs)

    @staticmethod
    def post(url, **kwargs):
        """Describe a POST request."""
        return Request('POST', url, kwargs)

    @staticmethod
    def put(url, **kwargs):
        """Describe a PUT request."""
        return Request('PUT', url, kwargs)

    @staticmethod
    def patch(url, **kwargs):
        """Describe a PATCH request."""
        return Request('PATCH', url, kwargs)

    @staticmethod
    def delete(url, **kwargs):
        """Describe a DELETE request."""
        return Request('DELETE', url, kwargs)


def get(url, **kwargs):
    """Send a GET request; ``allow_redirects`` defaults to True."""
    kwargs.setdefault('allow_redirects', True)
    return request('get', url, **kwargs)


def options(url, **kwargs):
    """Send an OPTIONS request; ``allow_redirects`` defaults to True."""
    kwargs.setdefault('allow_redirects', True)
    return request('options', url, **kwargs)


def head(url, **kwargs):
    """Send a HEAD request; ``allow_redirects`` defaults to False."""
    kwargs.setdefault('allow_redirects', False)
    return request('head', url, **kwargs)


def post(url, data=None, **kwargs):
    """Send a POST request with the optional ``data`` payload."""
    return request('post', url, data=data, **kwargs)


def put(url, data=None, **kwargs):
    """Send a PUT request with the optional ``data`` payload."""
    return request('put', url, data=data, **kwargs)


def patch(url, data=None, **kwargs):
    """Send a PATCH request with the optional ``data`` payload."""
    return request('patch', url, data=data, **kwargs)


def delete(url, **kwargs):
    """Send a DELETE request."""
    return request('delete', url, **kwargs)


async def stream_chunk_to_queue(network, queue, method, url, **kwargs):
    """Stream a response into ``queue``.

    The items are queued in this order: the response object, then each
    non-empty raw chunk, then ``None`` as the end marker.  If an unexpected
    exception occurs, the exception object is queued before the end marker.
    """
    try:
        async with await network.stream(method, url, **kwargs) as response:
            queue.put(response)
            # aiter_raw: access the raw bytes on the response without applying any HTTP content decoding
            # https://www.python-httpx.org/quickstart/#streaming-responses
            async for chunk in response.aiter_raw(65536):
                if len(chunk) > 0:
                    queue.put(chunk)
    except (httpx.StreamClosed, anyio.ClosedResourceError):
        # the response was queued before the exception.
        # the exception was raised on aiter_raw.
        # we do nothing here: in the finally block, None will be queued
        # so stream(method, url, **kwargs) generator can stop
        pass
    except Exception as e:  # pylint: disable=broad-except
        # broad except to avoid this scenario:
        # exception in network.stream(method, url, **kwargs)
        # -> the exception is not caught here
        # -> queue None (in finally)
        # -> the function below stream(method, url, **kwargs) has nothing to return
        queue.put(e)
    finally:
        queue.put(None)


def _stream_generator(method, url, **kwargs):
    """Yield the items queued by :py:obj:`stream_chunk_to_queue`.

    The coroutine is scheduled on the loop returned by ``get_loop()``.  Queued
    exceptions are raised instead of being yielded; the generator stops on the
    ``None`` end marker.
    """
    queue = SimpleQueue()
    network = get_context_network()
    future = asyncio.run_coroutine_threadsafe(stream_chunk_to_queue(network, queue, method, url, **kwargs), get_loop())

    # yield chunks
    obj_or_exception = queue.get()
    while obj_or_exception is not None:
        if isinstance(obj_or_exception, Exception):
            raise obj_or_exception
        yield obj_or_exception
        obj_or_exception = queue.get()
    # wait for the coroutine to finish once the end marker has been received
    future.result()


def _close_response_method(self):
    """Replacement for ``response.close`` installed by :py:obj:`stream`."""
    asyncio.run_coroutine_threadsafe(self.aclose(), get_loop())
    # reach the end of self._generator ( _stream_generator ) to avoid a memory leak.
    # it makes sure that :
    # * the httpx response is closed (see the stream_chunk_to_queue function)
    # * to call future.result() in _stream_generator
    for _ in self._generator:  # pylint: disable=protected-access
        continue


def stream(method, url, **kwargs) -> Tuple[httpx.Response, Iterable[bytes]]:
    """Replace httpx.stream.

    Usage:
    response, stream = poolrequests.stream(...)
    for chunk in stream:
        ...

    httpx.Client.stream requires to write the httpx.HTTPTransport version of
    the httpx.AsyncHTTPTransport declared above.
    """
    generator = _stream_generator(method, url, **kwargs)

    # yield response
    response = next(generator)  # pylint: disable=stop-iteration-return
    if isinstance(response, Exception):
        raise response

    # keep a reference to the generator so close() can exhaust it
    response._generator = generator  # pylint: disable=protected-access
    response.close = MethodType(_close_response_method, response)

    return response, generator