# SPDX-License-Identifier: AGPL-3.0-or-later

import asyncio
import threading
import concurrent.futures
from timeit import default_timer

import httpx
import h2.exceptions

from .network import get_network, initialize
from .client import get_loop
from .raise_for_httperror import raise_for_httperror

# queue.SimpleQueue: Support Python 3.6
try:
    from queue import SimpleQueue
except ImportError:
    from queue import Empty
    from collections import deque

    class SimpleQueue:
        """Minimal backport of queue.SimpleQueue"""

        def __init__(self):
            self._queue = deque()
            # number of items available in self._queue; get() blocks on it
            self._count = threading.Semaphore(0)

        def put(self, item):
            self._queue.append(item)
            self._count.release()

        def get(self):
            if not self._count.acquire(True):
                raise Empty
            return self._queue.popleft()


# Per-thread state: total_time, timeout, start_time, network
THREADLOCAL = threading.local()


def reset_time_for_thread():
    """Start (or restart) measuring the time this thread spends in :py:func:`request`.

    Until this is called, :py:func:`request` does not record any time for the thread.
    """
    THREADLOCAL.total_time = 0


def get_time_for_thread():
    """returns thread's total time or None"""
    return THREADLOCAL.__dict__.get('total_time')


def set_timeout_for_thread(timeout, start_time=None):
    """Set the default timeout of the requests sent by this thread.

    :param timeout: used by :py:func:`request` when no ``timeout`` keyword is given.
    :param start_time: optional :py:func:`timeit.default_timer` value; when set,
        the time elapsed since ``start_time`` is subtracted from the time
        :py:func:`request` waits for a response.
    """
    THREADLOCAL.timeout = timeout
    THREADLOCAL.start_time = start_time


def set_context_network_name(network_name):
    """Use the network named ``network_name`` for the requests sent by this thread."""
    THREADLOCAL.network = get_network(network_name)


def get_context_network():
    """If set return thread's network.

    If unset, return value from :py:obj:`get_network`.
    """
    return THREADLOCAL.__dict__.get('network') or get_network()


def request(method, url, **kwargs):
    """same as requests/requests/api.py request(...)

    The request is sent by the thread's network (see :py:func:`get_context_network`)
    in the event loop returned by ``get_loop()``; this function blocks until the
    response arrives.

    Keyword arguments handled here (the others are forwarded to the network):

    - ``timeout``: when missing, the thread's timeout (see
      :py:func:`set_timeout_for_thread`) is used and forwarded, if there is one.
    - ``raise_for_httperror`` (default ``True``): when true, the response is
      checked with ``raise_for_httperror``.

    :param url: ``str`` or ``bytes`` (decoded for requests compatibility).
    :returns: the response, with an extra ``ok`` attribute (requests compatibility).
    :raises httpx.TimeoutException: no response within the timeout.
    """
    time_before_request = default_timer()

    # timeout (httpx)
    if 'timeout' in kwargs:
        timeout = kwargs['timeout']
    else:
        timeout = getattr(THREADLOCAL, 'timeout', None)
        if timeout is not None:
            kwargs['timeout'] = timeout

    # 2 minutes timeout for the requests without timeout
    timeout = timeout or 120

    # adjust actual timeout
    timeout += 0.2  # overhead
    start_time = getattr(THREADLOCAL, 'start_time', time_before_request)
    if start_time:
        # remove the time elapsed since start_time (the start of this call,
        # unless set_timeout_for_thread() provided an earlier one)
        timeout -= default_timer() - start_time

    # raise_for_error: not an argument of the network, remove it from kwargs
    check_for_httperror = kwargs.pop('raise_for_httperror', True)

    # requests compatibility
    if isinstance(url, bytes):
        url = url.decode()

    # network
    network = get_context_network()

    # do request
    future = asyncio.run_coroutine_threadsafe(network.request(method, url, **kwargs), get_loop())
    try:
        response = future.result(timeout)
    except concurrent.futures.TimeoutError as e:
        # nobody waits for the result anymore: stop the coroutine in the event loop
        future.cancel()
        raise httpx.TimeoutException('Timeout', request=None) from e

    # requests compatibility
    # see also https://www.python-httpx.org/compatibility/#checking-for-4xx5xx-responses
    response.ok = not response.is_error

    # update total_time.
    # See get_time_for_thread() and reset_time_for_thread()
    if hasattr(THREADLOCAL, 'total_time'):
        time_after_request = default_timer()
        THREADLOCAL.total_time += time_after_request - time_before_request

    # raise an exception
    if check_for_httperror:
        raise_for_httperror(response)

    return response


def get(url, **kwargs):
    """GET request; follows redirects unless ``allow_redirects`` is given."""
    kwargs.setdefault('allow_redirects', True)
    return request('get', url, **kwargs)


def options(url, **kwargs):
    """OPTIONS request; follows redirects unless ``allow_redirects`` is given."""
    kwargs.setdefault('allow_redirects', True)
    return request('options', url, **kwargs)


def head(url, **kwargs):
    """HEAD request; does not follow redirects unless ``allow_redirects`` is given."""
    kwargs.setdefault('allow_redirects', False)
    return request('head', url, **kwargs)


def post(url, data=None, **kwargs):
    """POST request with an optional body ``data``."""
    return request('post', url, data=data, **kwargs)


def put(url, data=None, **kwargs):
    """PUT request with an optional body ``data``."""
    return request('put', url, data=data, **kwargs)


def patch(url, data=None, **kwargs):
    """PATCH request with an optional body ``data``."""
    return request('patch', url, data=data, **kwargs)


def delete(url, **kwargs):
    """DELETE request."""
    return request('delete', url, **kwargs)


async def stream_chunk_to_queue(network, q, method, url, **kwargs):
    """Run a streamed request and hand its results over to ``q``.

    Puts into ``q``, in this order: the response, then every non-empty chunk of
    the body (read 65536 bytes at a time), then ``None`` as end marker.
    An ``httpx.HTTPError``, ``OSError`` or ``h2.exceptions.ProtocolError`` is put
    into ``q`` (before the end marker) instead of being raised.
    """
    try:
        async with network.stream(method, url, **kwargs) as response:
            q.put(response)
            async for chunk in response.aiter_bytes(65536):
                if len(chunk) > 0:
                    q.put(chunk)
    except (httpx.HTTPError, OSError, h2.exceptions.ProtocolError) as e:
        q.put(e)
    finally:
        q.put(None)


def stream(method, url, **kwargs):
    """Replace httpx.stream.

    Usage::

        it = stream('GET', url)
        response = next(it)
        for chunk in it:
            ...

    Using httpx.Client.stream would require a synchronous httpx.HTTPTransport
    counterpart of the asynchronous transport used by the network.  Instead, the
    request runs in the event loop (see :py:func:`stream_chunk_to_queue`) with the
    thread's network (see :py:func:`get_context_network`), and the response then
    the chunks are read back from a queue.

    When the generator is closed before the end of the body, the download running
    in the event loop is cancelled.

    :raises: the exception put into the queue by :py:func:`stream_chunk_to_queue`,
        or the one raised by the coroutine itself.
    """
    q = SimpleQueue()
    future = asyncio.run_coroutine_threadsafe(
        stream_chunk_to_queue(get_context_network(), q, method, url, **kwargs),
        get_loop(),
    )
    try:
        chunk_or_exception = q.get()
        while chunk_or_exception is not None:
            if isinstance(chunk_or_exception, Exception):
                raise chunk_or_exception
            yield chunk_or_exception
            chunk_or_exception = q.get()
        # re-raise an exception the coroutine did not put into the queue
        return future.result()
    finally:
        # the consumer may stop early (generator closed or garbage collected):
        # stop the download instead of letting it fill a queue nobody reads
        if not future.done():
            future.cancel()