Browse Source

[enh] verify that Tor proxy works every time searx starts

based on @MarcAbonce commit on searx
Alexandre Flament 3 years ago
parent
commit
f9c6393502

+ 2 - 2
searx/network/__init__.py

@@ -12,7 +12,7 @@ import httpx
 import anyio
 import anyio
 import h2.exceptions
 import h2.exceptions
 
 
-from .network import get_network, initialize
+from .network import get_network, initialize, check_network_configuration
 from .client import get_loop
 from .client import get_loop
 from .raise_for_httperror import raise_for_httperror
 from .raise_for_httperror import raise_for_httperror
 
 
@@ -160,7 +160,7 @@ def delete(url, **kwargs):
 
 
 async def stream_chunk_to_queue(network, queue, method, url, **kwargs):
 async def stream_chunk_to_queue(network, queue, method, url, **kwargs):
     try:
     try:
-        async with network.stream(method, url, **kwargs) as response:
+        async with await network.stream(method, url, **kwargs) as response:
             queue.put(response)
             queue.put(response)
             # aiter_raw: access the raw bytes on the response without applying any HTTP content decoding
             # aiter_raw: access the raw bytes on the response without applying any HTTP content decoding
             # https://www.python-httpx.org/quickstart/#streaming-responses
             # https://www.python-httpx.org/quickstart/#streaming-responses

+ 53 - 7
searx/network/network.py

@@ -11,7 +11,7 @@ from itertools import cycle
 import httpx
 import httpx
 
 
 from searx import logger, searx_debug
 from searx import logger, searx_debug
-from .client import new_client, get_loop
+from .client import new_client, get_loop, AsyncHTTPTransportNoHttp
 
 
 
 
 logger = logger.getChild('network')
 logger = logger.getChild('network')
@@ -42,10 +42,12 @@ class Network:
     __slots__ = (
     __slots__ = (
         'enable_http', 'verify', 'enable_http2',
         'enable_http', 'verify', 'enable_http2',
         'max_connections', 'max_keepalive_connections', 'keepalive_expiry',
         'max_connections', 'max_keepalive_connections', 'keepalive_expiry',
-        'local_addresses', 'proxies', 'max_redirects', 'retries', 'retry_on_http_error',
+        'local_addresses', 'proxies', 'using_tor_proxy', 'max_redirects', 'retries', 'retry_on_http_error',
         '_local_addresses_cycle', '_proxies_cycle', '_clients', '_logger'
         '_local_addresses_cycle', '_proxies_cycle', '_clients', '_logger'
     )
     )
 
 
+    _TOR_CHECK_RESULT = {}
+
     def __init__(
     def __init__(
             # pylint: disable=too-many-arguments
             # pylint: disable=too-many-arguments
             self,
             self,
@@ -56,6 +58,7 @@ class Network:
             max_keepalive_connections=None,
             max_keepalive_connections=None,
             keepalive_expiry=None,
             keepalive_expiry=None,
             proxies=None,
             proxies=None,
+            using_tor_proxy=False,
             local_addresses=None,
             local_addresses=None,
             retries=0,
             retries=0,
             retry_on_http_error=None,
             retry_on_http_error=None,
@@ -69,6 +72,7 @@ class Network:
         self.max_keepalive_connections = max_keepalive_connections
         self.max_keepalive_connections = max_keepalive_connections
         self.keepalive_expiry = keepalive_expiry
         self.keepalive_expiry = keepalive_expiry
         self.proxies = proxies
         self.proxies = proxies
+        self.using_tor_proxy = using_tor_proxy
         self.local_addresses = local_addresses
         self.local_addresses = local_addresses
         self.retries = retries
         self.retries = retries
         self.retry_on_http_error = retry_on_http_error
         self.retry_on_http_error = retry_on_http_error
@@ -144,7 +148,27 @@ class Network:
             f'HTTP Request: {request.method} {request.url} "{response_line}"{content_type}'
             f'HTTP Request: {request.method} {request.url} "{response_line}"{content_type}'
         )
         )
 
 
-    def get_client(self, verify=None, max_redirects=None):
+    @staticmethod
+    async def check_tor_proxy(client: httpx.AsyncClient, proxies) -> bool:
+        if proxies in Network._TOR_CHECK_RESULT:
+            return Network._TOR_CHECK_RESULT[proxies]
+
+        result = True
+        # ignore client._transport because it is not used with all://
+        for transport in client._mounts.values():  # pylint: disable=protected-access
+            if isinstance(transport, AsyncHTTPTransportNoHttp):
+                continue
+            if not getattr(transport, '_rdns', False):
+                result = False
+                break
+        else:
+            response = await client.get('https://check.torproject.org/api/ip')
+            if not response.json()['IsTor']:
+                result = False
+        Network._TOR_CHECK_RESULT[proxies] = result
+        return result
+
+    async def get_client(self, verify=None, max_redirects=None):
         verify = self.verify if verify is None else verify
         verify = self.verify if verify is None else verify
         max_redirects = self.max_redirects if max_redirects is None else max_redirects
         max_redirects = self.max_redirects if max_redirects is None else max_redirects
         local_address = next(self._local_addresses_cycle)
         local_address = next(self._local_addresses_cycle)
@@ -152,7 +176,7 @@ class Network:
         key = (verify, max_redirects, local_address, proxies)
         key = (verify, max_redirects, local_address, proxies)
         hook_log_response = self.log_response if searx_debug else None
         hook_log_response = self.log_response if searx_debug else None
         if key not in self._clients or self._clients[key].is_closed:
         if key not in self._clients or self._clients[key].is_closed:
-            self._clients[key] = new_client(
+            client = new_client(
                 self.enable_http,
                 self.enable_http,
                 verify,
                 verify,
                 self.enable_http2,
                 self.enable_http2,
@@ -165,6 +189,10 @@ class Network:
                 max_redirects,
                 max_redirects,
                 hook_log_response
                 hook_log_response
             )
             )
+            if self.using_tor_proxy and not await self.check_tor_proxy(client, proxies):
+                await client.aclose()
+                raise httpx.ProxyError('Network configuration problem: not using Tor')
+            self._clients[key] = client
         return self._clients[key]
         return self._clients[key]
 
 
     async def aclose(self):
     async def aclose(self):
@@ -197,7 +225,7 @@ class Network:
         retries = self.retries
         retries = self.retries
         while retries >= 0:  # pragma: no cover
         while retries >= 0:  # pragma: no cover
             kwargs_clients = Network.get_kwargs_clients(kwargs)
             kwargs_clients = Network.get_kwargs_clients(kwargs)
-            client = self.get_client(**kwargs_clients)
+            client = await self.get_client(**kwargs_clients)
             try:
             try:
                 response = await client.request(method, url, **kwargs)
                 response = await client.request(method, url, **kwargs)
                 if self.is_valid_respones(response) or retries <= 0:
                 if self.is_valid_respones(response) or retries <= 0:
@@ -207,11 +235,11 @@ class Network:
                     raise e
                     raise e
             retries -= 1
             retries -= 1
 
 
-    def stream(self, method, url, **kwargs):
+    async def stream(self, method, url, **kwargs):
         retries = self.retries
         retries = self.retries
         while retries >= 0:  # pragma: no cover
         while retries >= 0:  # pragma: no cover
             kwargs_clients = Network.get_kwargs_clients(kwargs)
             kwargs_clients = Network.get_kwargs_clients(kwargs)
-            client = self.get_client(**kwargs_clients)
+            client = await self.get_client(**kwargs_clients)
             try:
             try:
                 response = client.stream(method, url, **kwargs)
                 response = client.stream(method, url, **kwargs)
                 if self.is_valid_respones(response) or retries <= 0:
                 if self.is_valid_respones(response) or retries <= 0:
@@ -230,6 +258,23 @@ def get_network(name=None):
     return NETWORKS.get(name or DEFAULT_NAME)
     return NETWORKS.get(name or DEFAULT_NAME)
 
 
 
 
def check_network_configuration():
    """Eagerly open a client for every network configured to use a Tor proxy.

    ``Network.get_client`` raises when the Tor check fails, so this surfaces
    misconfigured proxies at startup.  Raises ``RuntimeError`` if at least one
    such network failed its check.
    """
    async def _count_failures():
        failures = 0
        for network in NETWORKS.values():
            if not network.using_tor_proxy:
                continue
            try:
                await network.get_client()
            except Exception:  # pylint: disable=broad-except
                network._logger.exception('Error')  # pylint: disable=protected-access
                failures += 1
        return failures

    # the clients live on the networking event loop, so run the probe there
    future = asyncio.run_coroutine_threadsafe(_count_failures(), get_loop())
    if future.result() > 0:
        raise RuntimeError("Invalid network configuration")
+
+
 def initialize(settings_engines=None, settings_outgoing=None):
 def initialize(settings_engines=None, settings_outgoing=None):
     # pylint: disable=import-outside-toplevel)
     # pylint: disable=import-outside-toplevel)
     from searx.engines import engines
     from searx.engines import engines
@@ -249,6 +294,7 @@ def initialize(settings_engines=None, settings_outgoing=None):
         'max_keepalive_connections': settings_outgoing['pool_maxsize'],
         'max_keepalive_connections': settings_outgoing['pool_maxsize'],
         'keepalive_expiry': settings_outgoing['keepalive_expiry'],
         'keepalive_expiry': settings_outgoing['keepalive_expiry'],
         'local_addresses': settings_outgoing['source_ips'],
         'local_addresses': settings_outgoing['source_ips'],
+        'using_tor_proxy': settings_outgoing['using_tor_proxy'],
         'proxies': settings_outgoing['proxies'],
         'proxies': settings_outgoing['proxies'],
         'max_redirects': settings_outgoing['max_redirects'],
         'max_redirects': settings_outgoing['max_redirects'],
         'retries': settings_outgoing['retries'],
         'retries': settings_outgoing['retries'],

+ 4 - 2
searx/search/__init__.py

@@ -15,7 +15,7 @@ from searx import logger
 from searx.plugins import plugins
 from searx.plugins import plugins
 from searx.search.models import EngineRef, SearchQuery
 from searx.search.models import EngineRef, SearchQuery
 from searx.engines import load_engines
 from searx.engines import load_engines
-from searx.network import initialize as initialize_network
+from searx.network import initialize as initialize_network, check_network_configuration
 from searx.metrics import initialize as initialize_metrics, counter_inc, histogram_observe_time
 from searx.metrics import initialize as initialize_metrics, counter_inc, histogram_observe_time
 from searx.search.processors import PROCESSORS, initialize as initialize_processors
 from searx.search.processors import PROCESSORS, initialize as initialize_processors
 from searx.search.checker import initialize as initialize_checker
 from searx.search.checker import initialize as initialize_checker
@@ -24,10 +24,12 @@ from searx.search.checker import initialize as initialize_checker
 logger = logger.getChild('search')
 logger = logger.getChild('search')
 
 
 
 
-def initialize(settings_engines=None, enable_checker=False):
+def initialize(settings_engines=None, enable_checker=False, check_network=False):
     settings_engines = settings_engines or settings['engines']
     settings_engines = settings_engines or settings['engines']
     load_engines(settings_engines)
     load_engines(settings_engines)
     initialize_network(settings_engines, settings['outgoing'])
     initialize_network(settings_engines, settings['outgoing'])
+    if check_network:
+        check_network_configuration()
     initialize_metrics([engine['name'] for engine in settings_engines])
     initialize_metrics([engine['name'] for engine in settings_engines])
     initialize_processors(settings_engines)
     initialize_processors(settings_engines)
     if enable_checker:
     if enable_checker:

+ 1 - 1
searx/webapp.py

@@ -1350,7 +1350,7 @@ if (not werkzeug_reloader
     or (werkzeug_reloader
     or (werkzeug_reloader
         and os.environ.get("WERKZEUG_RUN_MAIN") == "true") ):
         and os.environ.get("WERKZEUG_RUN_MAIN") == "true") ):
     plugin_initialize(app)
     plugin_initialize(app)
-    search_initialize(enable_checker=True)
+    search_initialize(enable_checker=True, check_network=True)
 
 
 
 
 def run():
 def run():

+ 10 - 10
tests/unit/network/test_network.py

@@ -90,12 +90,12 @@ class TestNetwork(SearxTestCase):
 
 
     async def test_get_client(self):
     async def test_get_client(self):
         network = Network(verify=True)
         network = Network(verify=True)
-        client1 = network.get_client()
-        client2 = network.get_client(verify=True)
-        client3 = network.get_client(max_redirects=10)
-        client4 = network.get_client(verify=True)
-        client5 = network.get_client(verify=False)
-        client6 = network.get_client(max_redirects=10)
+        client1 = await network.get_client()
+        client2 = await network.get_client(verify=True)
+        client3 = await network.get_client(max_redirects=10)
+        client4 = await network.get_client(verify=True)
+        client5 = await network.get_client(verify=False)
+        client6 = await network.get_client(max_redirects=10)
 
 
         self.assertEqual(client1, client2)
         self.assertEqual(client1, client2)
         self.assertEqual(client1, client4)
         self.assertEqual(client1, client4)
@@ -107,7 +107,7 @@ class TestNetwork(SearxTestCase):
 
 
     async def test_aclose(self):
     async def test_aclose(self):
         network = Network(verify=True)
         network = Network(verify=True)
-        network.get_client()
+        await network.get_client()
         await network.aclose()
         await network.aclose()
 
 
     async def test_request(self):
     async def test_request(self):
@@ -211,7 +211,7 @@ class TestNetworkStreamRetries(SearxTestCase):
     async def test_retries_ok(self):
     async def test_retries_ok(self):
         with patch.object(httpx.AsyncClient, 'stream', new=TestNetworkStreamRetries.get_response_exception_then_200()):
         with patch.object(httpx.AsyncClient, 'stream', new=TestNetworkStreamRetries.get_response_exception_then_200()):
             network = Network(enable_http=True, retries=1, retry_on_http_error=403)
             network = Network(enable_http=True, retries=1, retry_on_http_error=403)
-            response = network.stream('GET', 'https://example.com/')
+            response = await network.stream('GET', 'https://example.com/')
             self.assertEqual(response.text, TestNetworkStreamRetries.TEXT)
             self.assertEqual(response.text, TestNetworkStreamRetries.TEXT)
             await network.aclose()
             await network.aclose()
 
 
@@ -219,7 +219,7 @@ class TestNetworkStreamRetries(SearxTestCase):
         with patch.object(httpx.AsyncClient, 'stream', new=TestNetworkStreamRetries.get_response_exception_then_200()):
         with patch.object(httpx.AsyncClient, 'stream', new=TestNetworkStreamRetries.get_response_exception_then_200()):
             network = Network(enable_http=True, retries=0, retry_on_http_error=403)
             network = Network(enable_http=True, retries=0, retry_on_http_error=403)
             with self.assertRaises(httpx.RequestError):
             with self.assertRaises(httpx.RequestError):
-                network.stream('GET', 'https://example.com/')
+                await network.stream('GET', 'https://example.com/')
             await network.aclose()
             await network.aclose()
 
 
     async def test_retries_exception(self):
     async def test_retries_exception(self):
@@ -234,6 +234,6 @@ class TestNetworkStreamRetries(SearxTestCase):
 
 
         with patch.object(httpx.AsyncClient, 'stream', new=stream):
         with patch.object(httpx.AsyncClient, 'stream', new=stream):
             network = Network(enable_http=True, retries=0, retry_on_http_error=403)
             network = Network(enable_http=True, retries=0, retry_on_http_error=403)
-            response = network.stream('GET', 'https://example.com/')
+            response = await network.stream('GET', 'https://example.com/')
             self.assertEqual(response.status_code, 403)
             self.assertEqual(response.status_code, 403)
             await network.aclose()
             await network.aclose()