Browse Source

Merge pull request #343 from dalf/fix-checker-memory-issue

[fix] checker: fix memory usage
Alexandre Flament 3 years ago
parent
commit
c23aa5760c
3 changed files with 81 additions and 35 deletions
  1. 31 16
      searx/network/__init__.py
  2. 7 0
      searx/network/client.py
  3. 43 19
      searx/search/checker/impl.py

+ 31 - 16
searx/network/__init__.py

@@ -9,6 +9,7 @@ from types import MethodType
 from timeit import default_timer
 
 import httpx
+import anyio
 import h2.exceptions
 
 from .network import get_network, initialize
@@ -166,7 +167,7 @@ async def stream_chunk_to_queue(network, queue, method, url, **kwargs):
             async for chunk in response.aiter_raw(65536):
                 if len(chunk) > 0:
                     queue.put(chunk)
-    except httpx.StreamClosed:
+    except (httpx.StreamClosed, anyio.ClosedResourceError):
         # the response was queued before the exception.
         # the exception was raised on aiter_raw.
         # we do nothing here: in the finally block, None will be queued
@@ -183,11 +184,35 @@ async def stream_chunk_to_queue(network, queue, method, url, **kwargs):
         queue.put(None)
 
 
+def _stream_generator(method, url, **kwargs):
+    queue = SimpleQueue()
+    network = get_context_network()
+    future = asyncio.run_coroutine_threadsafe(
+        stream_chunk_to_queue(network, queue, method, url, **kwargs),
+        get_loop()
+    )
+
+    # yield chunks
+    obj_or_exception = queue.get()
+    while obj_or_exception is not None:
+        if isinstance(obj_or_exception, Exception):
+            raise obj_or_exception
+        yield obj_or_exception
+        obj_or_exception = queue.get()
+    future.result()
+
+
 def _close_response_method(self):
     asyncio.run_coroutine_threadsafe(
         self.aclose(),
         get_loop()
     )
+    # reach the end of _self.generator ( _stream_generator ) to an avoid memory leak.
+    # it makes sure that :
+    # * the httpx response is closed (see the stream_chunk_to_queue function)
+    # * to call future.result() in _stream_generator
+    for _ in self._generator:  # pylint: disable=protected-access
+        continue
 
 
 def stream(method, url, **kwargs):
@@ -202,25 +227,15 @@ def stream(method, url, **kwargs):
     httpx.Client.stream requires to write the httpx.HTTPTransport version of the
     the httpx.AsyncHTTPTransport declared above.
     """
-    queue = SimpleQueue()
-    network = get_context_network()
-    future = asyncio.run_coroutine_threadsafe(
-        stream_chunk_to_queue(network, queue, method, url, **kwargs),
-        get_loop()
-    )
+    generator = _stream_generator(method, url, **kwargs)
 
     # yield response
-    response = queue.get()
+    response = next(generator)  # pylint: disable=stop-iteration-return
     if isinstance(response, Exception):
         raise response
+
+    response._generator = generator  # pylint: disable=protected-access
     response.close = MethodType(_close_response_method, response)
     yield response
 
-    # yield chunks
-    chunk_or_exception = queue.get()
-    while chunk_or_exception is not None:
-        if isinstance(chunk_or_exception, Exception):
-            raise chunk_or_exception
-        yield chunk_or_exception
-        chunk_or_exception = queue.get()
-    future.result()
+    yield from generator

+ 7 - 0
searx/network/client.py

@@ -6,6 +6,7 @@ import asyncio
 import logging
 import threading
 
+import anyio
 import httpcore
 import httpx
 from httpx_socks import AsyncProxyTransport
@@ -102,6 +103,9 @@ class AsyncProxyTransportFixed(AsyncProxyTransport):
                 # then each new request creates a new stream and raise the same WriteError
                 await close_connections_for_url(self, url)
                 raise e
+            except anyio.ClosedResourceError as e:
+                await close_connections_for_url(self, url)
+                raise httpx.CloseError from e
             except httpx.RemoteProtocolError as e:
                 # in case of httpx.RemoteProtocolError: Server disconnected
                 await close_connections_for_url(self, url)
@@ -130,6 +134,9 @@ class AsyncHTTPTransportFixed(httpx.AsyncHTTPTransport):
                 # then each new request creates a new stream and raise the same WriteError
                 await close_connections_for_url(self._pool, url)
                 raise e
+            except anyio.ClosedResourceError as e:
+                await close_connections_for_url(self._pool, url)
+                raise httpx.CloseError from e
             except httpx.RemoteProtocolError as e:
                 # in case of httpx.RemoteProtocolError: Server disconnected
                 await close_connections_for_url(self._pool, url)

+ 43 - 19
searx/search/checker/impl.py

@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: AGPL-3.0-or-later
 
+import gc
 import typing
 import types
 import functools
@@ -14,6 +15,7 @@ from langdetect.lang_detect_exception import LangDetectException
 import httpx
 
 from searx import network, logger
+from searx.utils import gen_useragent
 from searx.results import ResultContainer
 from searx.search.models import SearchQuery, EngineRef
 from searx.search.processors import EngineProcessor
@@ -58,27 +60,20 @@ def _is_url(url):
 
 
 @functools.lru_cache(maxsize=8192)
-def _is_url_image(image_url):
-    if not isinstance(image_url, str):
-        return False
-
-    if image_url.startswith('//'):
-        image_url = 'https:' + image_url
-
-    if image_url.startswith('data:'):
-        return image_url.startswith('data:image/')
-
-    if not _is_url(image_url):
-        return False
-
+def _download_and_check_if_image(image_url: str) -> bool:
+    """Download an URL and check if the Content-Type starts with "image/"
+    This function should not be called directly: use _is_url_image
+    otherwise the cache of functools.lru_cache contains data: URL which might be huge.
+    """
     retry = 2
 
     while retry > 0:
         a = time()
         try:
-            network.set_timeout_for_thread(10.0, time())
-            r = network.get(image_url, timeout=10.0, allow_redirects=True, headers={
-                'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:84.0) Gecko/20100101 Firefox/84.0',
+            # use "image_proxy" (avoid HTTP/2)
+            network.set_context_network_name('image_proxy')
+            stream = network.stream('GET', image_url, timeout=10.0, allow_redirects=True, headers={
+                'User-Agent': gen_useragent(),
                 'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8',
                 'Accept-Language': 'en-US;q=0.5,en;q=0.3',
                 'Accept-Encoding': 'gzip, deflate, br',
@@ -88,15 +83,40 @@ def _is_url_image(image_url):
                 'Sec-GPC': '1',
                 'Cache-Control': 'max-age=0'
             })
-            if r.headers["content-type"].startswith('image/'):
-                return True
-            return False
+            r = next(stream)
+            r.close()
+            if r.status_code == 200:
+                is_image = r.headers.get('content-type', '').startswith('image/')
+            else:
+                is_image = False
+            del r
+            del stream
+            return is_image
         except httpx.TimeoutException:
             logger.error('Timeout for %s: %i', image_url, int(time() - a))
             retry -= 1
         except httpx.HTTPError:
             logger.exception('Exception for %s', image_url)
             return False
+    return False
+
+
+def _is_url_image(image_url) -> bool:
+    """Normalize image_url
+    """
+    if not isinstance(image_url, str):
+        return False
+
+    if image_url.startswith('//'):
+        image_url = 'https:' + image_url
+
+    if image_url.startswith('data:'):
+        return image_url.startswith('data:image/')
+
+    if not _is_url(image_url):
+        return False
+
+    return _download_and_check_if_image(image_url)
 
 
 def _search_query_to_dict(search_query: SearchQuery) -> typing.Dict[str, typing.Any]:
@@ -414,3 +434,7 @@ class Checker:
     def run(self):
         for test_name in self.tests:
             self.run_test(test_name)
+            # clear cache
+            _download_and_check_if_image.cache_clear()
+            # force a garbage collector
+            gc.collect()