
[fix] searx.network.stream: fix memory leak

Alexandre Flament · 3 years ago · commit 29893cf816
3 changed files with 42 additions and 17 deletions:
  1. searx/network/__init__.py (+31 −16)
  2. searx/network/client.py (+7 −0)
  3. searx/search/checker/impl.py (+4 −1)
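
What leaked: stream() itself held the queue and the background future, so a
caller that stopped after the first next() (as searx/search/checker/impl.py
does) never drained the queue, never closed the httpx response, and never
resolved the future. The fix moves the chunk loop into _stream_generator()
and makes response.close() drain it. A minimal consumer sketch, assuming the
patched searx.network below; the URL is a placeholder:

from searx import network

stream = network.stream('GET', 'https://example.org/image.png')  # placeholder
r = next(stream)   # the first yielded item is the httpx.Response
r.close()          # after this commit: drains _stream_generator, which closes
                   # the httpx response and calls future.result()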

+ 31 - 16
searx/network/__init__.py

@@ -9,6 +9,7 @@ from types import MethodType
 from timeit import default_timer
 
 import httpx
+import anyio
 import h2.exceptions
 
 from .network import get_network, initialize
@@ -166,7 +167,7 @@ async def stream_chunk_to_queue(network, queue, method, url, **kwargs):
             async for chunk in response.aiter_raw(65536):
                 if len(chunk) > 0:
                     queue.put(chunk)
-    except httpx.StreamClosed:
+    except (httpx.StreamClosed, anyio.ClosedResourceError):
         # the response was queued before the exception.
         # the exception was raised on aiter_raw.
         # we do nothing here: in the finally block, None will be queued
@@ -183,11 +184,35 @@ async def stream_chunk_to_queue(network, queue, method, url, **kwargs):
         queue.put(None)


+def _stream_generator(method, url, **kwargs):
+    queue = SimpleQueue()
+    network = get_context_network()
+    future = asyncio.run_coroutine_threadsafe(
+        stream_chunk_to_queue(network, queue, method, url, **kwargs),
+        get_loop()
+    )
+
+    # yield chunks
+    obj_or_exception = queue.get()
+    while obj_or_exception is not None:
+        if isinstance(obj_or_exception, Exception):
+            raise obj_or_exception
+        yield obj_or_exception
+        obj_or_exception = queue.get()
+    future.result()
+
+
 def _close_response_method(self):
     asyncio.run_coroutine_threadsafe(
         self.aclose(),
         get_loop()
     )
+    # reach the end of self._generator ( _stream_generator ) to avoid a memory leak.
+    # it makes sure that:
+    # * the httpx response is closed (see the stream_chunk_to_queue function)
+    # * future.result() is called in _stream_generator
+    for _ in self._generator:  # pylint: disable=protected-access
+        continue


 def stream(method, url, **kwargs):
@@ -202,25 +227,15 @@ def stream(method, url, **kwargs):
     httpx.Client.stream requires writing the httpx.HTTPTransport version of
     the httpx.AsyncHTTPTransport declared above.
     """
-    queue = SimpleQueue()
-    network = get_context_network()
-    future = asyncio.run_coroutine_threadsafe(
-        stream_chunk_to_queue(network, queue, method, url, **kwargs),
-        get_loop()
-    )
+    generator = _stream_generator(method, url, **kwargs)

     # yield response
-    response = queue.get()
+    response = next(generator)  # pylint: disable=stop-iteration-return
     if isinstance(response, Exception):
         raise response
+
+    response._generator = generator  # pylint: disable=protected-access
     response.close = MethodType(_close_response_method, response)
     yield response

-    # yield chunks
-    chunk_or_exception = queue.get()
-    while chunk_or_exception is not None:
-        if isinstance(chunk_or_exception, Exception):
-            raise chunk_or_exception
-        yield chunk_or_exception
-        chunk_or_exception = queue.get()
-    future.result()
+    yield from generator
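
For reference, the code above bridges the asyncio loop to the calling thread:
stream_chunk_to_queue() runs on the loop and feeds a SimpleQueue,
_stream_generator() drains it from the caller's thread, None is the
end-of-stream sentinel, and exceptions travel through the queue so they can
be re-raised on the consumer side. A self-contained sketch of that bridge,
with every name invented for illustration:

import asyncio
import threading
from queue import SimpleQueue

# run an event loop in a background thread, like searx's get_loop() provides
loop = asyncio.new_event_loop()
threading.Thread(target=loop.run_forever, daemon=True).start()

async def produce(queue):
    try:
        for chunk in (b'a', b'b', b'c'):
            queue.put(chunk)
    except Exception as e:  # mirror stream_chunk_to_queue: hand errors
        queue.put(e)        # to the consumer thread through the queue
    finally:
        queue.put(None)     # sentinel: always unblocks the consumer

def consume():
    queue = SimpleQueue()
    future = asyncio.run_coroutine_threadsafe(produce(queue), loop)
    item = queue.get()
    while item is not None:
        if isinstance(item, Exception):
            raise item
        yield item
        item = queue.get()
    future.result()  # propagate anything the producer did not catch

print(list(consume()))  # [b'a', b'b', b'c']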

+ 7 - 0
searx/network/client.py

@@ -6,6 +6,7 @@ import asyncio
 import logging
 import threading
 
+import anyio
 import httpcore
 import httpx
 from httpx_socks import AsyncProxyTransport
@@ -102,6 +103,9 @@ class AsyncProxyTransportFixed(AsyncProxyTransport):
                 # then each new request creates a new stream and raise the same WriteError
                 await close_connections_for_url(self, url)
                 raise e
+            except anyio.ClosedResourceError as e:
+                await close_connections_for_url(self, url)
+                raise httpx.CloseError from e
             except httpx.RemoteProtocolError as e:
                 # in case of httpx.RemoteProtocolError: Server disconnected
                 await close_connections_for_url(self, url)
@@ -130,6 +134,9 @@ class AsyncHTTPTransportFixed(httpx.AsyncHTTPTransport):
                 # then each new request creates a new stream and raise the same WriteError
                 await close_connections_for_url(self._pool, url)
                 raise e
+            except anyio.ClosedResourceError as e:
+                await close_connections_for_url(self._pool, url)
+                raise httpx.CloseError from e
             except httpx.RemoteProtocolError as e:
                 # in case of httpx.RemoteProtocolError: Server disconnected
                 await close_connections_for_url(self._pool, url)
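
The client.py change translates anyio.ClosedResourceError into
httpx.CloseError after dropping the affected connections, so callers only
ever have to catch httpx exception types. A generic sketch of the
raise-from translation idiom; the exception classes are stand-ins, not the
searx or httpx ones:

class LibraryCloseError(Exception):  # stand-in for httpx.CloseError
    pass

class BackendClosed(Exception):      # stand-in for anyio.ClosedResourceError
    pass

def send():
    raise BackendClosed()

try:
    try:
        send()
    except BackendClosed as e:
        # 'from e' keeps the original error reachable as __cause__
        raise LibraryCloseError("stream closed") from e
except LibraryCloseError as err:
    assert isinstance(err.__cause__, BackendClosed)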

+ 4 - 1
searx/search/checker/impl.py

@@ -85,7 +85,10 @@ def _download_and_check_if_image(image_url: str) -> bool:
             })
             r = next(stream)
             r.close()
-            is_image = r.headers["content-type"].startswith('image/')
+            if r.status_code == 200:
+                is_image = r.headers.get('content-type', '').startswith('image/')
+            else:
+                is_image = False
             del r
             del stream
             return is_image
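
The checker change guards against two failure modes: a response without a
content-type header no longer raises KeyError, and non-200 responses are no
longer treated as images. A small illustrative check of the .get() guard,
using httpx.Headers directly; the header values are made up:

import httpx

headers = httpx.Headers({})                    # no content-type at all
assert headers.get('content-type', '') == ''   # .get() avoids a KeyError

headers = httpx.Headers({'content-type': 'image/png'})
assert headers.get('content-type', '').startswith('image/')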