Browse Source

Merge pull request #773 from not-my-profile/typing

More typing
Martin Fischer 3 years ago
parent
commit
96a1f79c6d

+ 3 - 1
searx/network/client.py

@@ -4,7 +4,9 @@
 
 
 import asyncio
 import asyncio
 import logging
 import logging
+from ssl import SSLContext
 import threading
 import threading
+from typing import Any, Dict
 
 
 import httpx
 import httpx
 from httpx_socks import AsyncProxyTransport
 from httpx_socks import AsyncProxyTransport
@@ -23,7 +25,7 @@ else:
 
 
 logger = logger.getChild('searx.network.client')
 logger = logger.getChild('searx.network.client')
 LOOP = None
 LOOP = None
-SSLCONTEXTS = {}
+SSLCONTEXTS: Dict[Any, SSLContext] = {}
 TRANSPORT_KWARGS = {
 TRANSPORT_KWARGS = {
     'trust_env': False,
     'trust_env': False,
 }
 }

+ 2 - 1
searx/network/network.py

@@ -7,6 +7,7 @@ import atexit
 import asyncio
 import asyncio
 import ipaddress
 import ipaddress
 from itertools import cycle
 from itertools import cycle
+from typing import Dict
 
 
 import httpx
 import httpx
 
 
@@ -16,7 +17,7 @@ from .client import new_client, get_loop, AsyncHTTPTransportNoHttp
 
 
 logger = logger.getChild('network')
 logger = logger.getChild('network')
 DEFAULT_NAME = '__DEFAULT__'
 DEFAULT_NAME = '__DEFAULT__'
-NETWORKS = {}
+NETWORKS: Dict[str, 'Network'] = {}
 # requests compatibility when reading proxy settings from settings.yml
 # requests compatibility when reading proxy settings from settings.yml
 PROXY_PATTERN_MAPPING = {
 PROXY_PATTERN_MAPPING = {
     'http': 'http://',
     'http': 'http://',

+ 4 - 1
searx/plugins/__init__.py

@@ -10,7 +10,7 @@ from os.path import abspath, basename, dirname, exists, join
 from shutil import copyfile
 from shutil import copyfile
 from pkgutil import iter_modules
 from pkgutil import iter_modules
 from logging import getLogger
 from logging import getLogger
-from typing import List
+from typing import List, Tuple
 
 
 from searx import logger, settings
 from searx import logger, settings
 
 
@@ -22,6 +22,9 @@ class Plugin:  # pylint: disable=too-few-public-methods
     name: str
     name: str
     description: str
     description: str
     default_on: bool
     default_on: bool
+    js_dependencies: Tuple[str]
+    css_dependencies: Tuple[str]
+    preference_section: str
 
 
 
 
 logger = logger.getChild("plugins")
 logger = logger.getChild("plugins")

+ 21 - 12
searx/results.py

@@ -2,7 +2,9 @@ import re
 from collections import defaultdict
 from collections import defaultdict
 from operator import itemgetter
 from operator import itemgetter
 from threading import RLock
 from threading import RLock
+from typing import List, NamedTuple, Set
 from urllib.parse import urlparse, unquote
 from urllib.parse import urlparse, unquote
+
 from searx import logger
 from searx import logger
 from searx.engines import engines
 from searx.engines import engines
 from searx.metrics import histogram_observe, counter_add, count_error
 from searx.metrics import histogram_observe, counter_add, count_error
@@ -137,6 +139,18 @@ def result_score(result):
     return sum((occurences * weight) / position for position in result['positions'])
     return sum((occurences * weight) / position for position in result['positions'])
 
 
 
 
+class Timing(NamedTuple):
+    engine: str
+    total: float
+    load: float
+
+
+class UnresponsiveEngine(NamedTuple):
+    engine: str
+    error_type: str
+    suspended: bool
+
+
 class ResultContainer:
 class ResultContainer:
     """docstring for ResultContainer"""
     """docstring for ResultContainer"""
 
 
@@ -168,8 +182,8 @@ class ResultContainer:
         self.engine_data = defaultdict(dict)
         self.engine_data = defaultdict(dict)
         self._closed = False
         self._closed = False
         self.paging = False
         self.paging = False
-        self.unresponsive_engines = set()
-        self.timings = []
+        self.unresponsive_engines: Set[UnresponsiveEngine] = set()
+        self.timings: List[Timing] = []
         self.redirect_url = None
         self.redirect_url = None
         self.on_result = lambda _: True
         self.on_result = lambda _: True
         self._lock = RLock()
         self._lock = RLock()
@@ -401,17 +415,12 @@ class ResultContainer:
             return 0
             return 0
         return resultnum_sum / len(self._number_of_results)
         return resultnum_sum / len(self._number_of_results)
 
 
-    def add_unresponsive_engine(self, engine_name, error_type, error_message=None, suspended=False):
+    def add_unresponsive_engine(self, engine_name: str, error_type: str, suspended: bool = False):
         if engines[engine_name].display_error_messages:
         if engines[engine_name].display_error_messages:
-            self.unresponsive_engines.add((engine_name, error_type, error_message, suspended))
-
-    def add_timing(self, engine_name, engine_time, page_load_time):
-        timing = {
-            'engine': engines[engine_name].shortcut,
-            'total': engine_time,
-            'load': page_load_time,
-        }
-        self.timings.append(timing)
+            self.unresponsive_engines.add(UnresponsiveEngine(engine_name, error_type, suspended))
+
+    def add_timing(self, engine_name: str, engine_time: float, page_load_time: float):
+        self.timings.append(Timing(engine_name, total=engine_time, load=page_load_time))
 
 
     def get_timings(self):
     def get_timings(self):
         return self.timings
         return self.timings

+ 2 - 1
searx/search/processors/__init__.py

@@ -15,6 +15,7 @@ __all__ = [
 ]
 ]
 
 
 import threading
 import threading
+from typing import Dict
 
 
 from searx import logger
 from searx import logger
 from searx import engines
 from searx import engines
@@ -26,7 +27,7 @@ from .online_currency import OnlineCurrencyProcessor
 from .abstract import EngineProcessor
 from .abstract import EngineProcessor
 
 
 logger = logger.getChild('search.processors')
 logger = logger.getChild('search.processors')
-PROCESSORS = {}
+PROCESSORS: Dict[str, EngineProcessor] = {}
 """Cache request processores, stored by *engine-name* (:py:func:`initialize`)"""
 """Cache request processores, stored by *engine-name* (:py:func:`initialize`)"""
 
 
 
 

+ 3 - 2
searx/search/processors/abstract.py

@@ -8,6 +8,7 @@
 import threading
 import threading
 from abc import abstractmethod, ABC
 from abc import abstractmethod, ABC
 from timeit import default_timer
 from timeit import default_timer
+from typing import Dict, Union
 
 
 from searx import settings, logger
 from searx import settings, logger
 from searx.engines import engines
 from searx.engines import engines
@@ -17,7 +18,7 @@ from searx.exceptions import SearxEngineAccessDeniedException, SearxEngineRespon
 from searx.utils import get_engine_from_settings
 from searx.utils import get_engine_from_settings
 
 
 logger = logger.getChild('searx.search.processor')
 logger = logger.getChild('searx.search.processor')
-SUSPENDED_STATUS = {}
+SUSPENDED_STATUS: Dict[Union[int, str], 'SuspendedStatus'] = {}
 
 
 
 
 class SuspendedStatus:
 class SuspendedStatus:
@@ -61,7 +62,7 @@ class EngineProcessor(ABC):
 
 
     __slots__ = 'engine', 'engine_name', 'lock', 'suspended_status', 'logger'
     __slots__ = 'engine', 'engine_name', 'lock', 'suspended_status', 'logger'
 
 
-    def __init__(self, engine, engine_name):
+    def __init__(self, engine, engine_name: str):
         self.engine = engine
         self.engine = engine
         self.engine_name = engine_name
         self.engine_name = engine_name
         self.logger = engines[engine_name].logger
         self.logger = engines[engine_name].logger

+ 42 - 27
searx/webapp.py

@@ -14,8 +14,11 @@ from datetime import datetime, timedelta
 from timeit import default_timer
 from timeit import default_timer
 from html import escape
 from html import escape
 from io import StringIO
 from io import StringIO
+import typing
+from typing import List, Dict, Iterable
 
 
 import urllib
 import urllib
+import urllib.parse
 from urllib.parse import urlencode
 from urllib.parse import urlencode
 
 
 import httpx
 import httpx
@@ -28,7 +31,6 @@ import flask
 
 
 from flask import (
 from flask import (
     Flask,
     Flask,
-    request,
     render_template,
     render_template,
     url_for,
     url_for,
     Response,
     Response,
@@ -55,6 +57,7 @@ from searx import (
     searx_debug,
     searx_debug,
 )
 )
 from searx.data import ENGINE_DESCRIPTIONS
 from searx.data import ENGINE_DESCRIPTIONS
+from searx.results import Timing, UnresponsiveEngine
 from searx.settings_defaults import OUTPUT_FORMATS
 from searx.settings_defaults import OUTPUT_FORMATS
 from searx.settings_loader import get_default_settings_path
 from searx.settings_loader import get_default_settings_path
 from searx.exceptions import SearxParameterException
 from searx.exceptions import SearxParameterException
@@ -89,7 +92,7 @@ from searx.utils import (
 )
 )
 from searx.version import VERSION_STRING, GIT_URL, GIT_BRANCH
 from searx.version import VERSION_STRING, GIT_URL, GIT_BRANCH
 from searx.query import RawTextQuery
 from searx.query import RawTextQuery
-from searx.plugins import plugins, initialize as plugin_initialize
+from searx.plugins import Plugin, plugins, initialize as plugin_initialize
 from searx.plugins.oa_doi_rewrite import get_doi_resolver
 from searx.plugins.oa_doi_rewrite import get_doi_resolver
 from searx.preferences import (
 from searx.preferences import (
     Preferences,
     Preferences,
@@ -224,6 +227,21 @@ exception_classname_to_text = {
 _flask_babel_get_translations = flask_babel.get_translations
 _flask_babel_get_translations = flask_babel.get_translations
 
 
 
 
+class ExtendedRequest(flask.Request):
+    """This class is never initialized and only used for type checking."""
+
+    preferences: Preferences
+    errors: List[str]
+    user_plugins: List[Plugin]
+    form: Dict[str, str]
+    start_time: float
+    render_time: float
+    timings: List[Timing]
+
+
+request = typing.cast(ExtendedRequest, flask.request)
+
+
 def _get_translations():
 def _get_translations():
     if has_request_context() and request.form.get('use-translation') == 'oc':
     if has_request_context() and request.form.get('use-translation') == 'oc':
         babel_ext = flask_babel.current_app.extensions['babel']
         babel_ext = flask_babel.current_app.extensions['babel']
@@ -321,7 +339,7 @@ def code_highlighter(codelines, language=None):
     return html_code
     return html_code
 
 
 
 
-def get_current_theme_name(override=None):
+def get_current_theme_name(override: str = None) -> str:
     """Returns theme name.
     """Returns theme name.
 
 
     Checks in this order:
     Checks in this order:
@@ -337,14 +355,14 @@ def get_current_theme_name(override=None):
     return theme_name
     return theme_name
 
 
 
 
-def get_result_template(theme_name, template_name):
+def get_result_template(theme_name: str, template_name: str):
     themed_path = theme_name + '/result_templates/' + template_name
     themed_path = theme_name + '/result_templates/' + template_name
     if themed_path in result_templates:
     if themed_path in result_templates:
         return themed_path
         return themed_path
     return 'result_templates/' + template_name
     return 'result_templates/' + template_name
 
 
 
 
-def url_for_theme(endpoint, override_theme=None, **values):
+def url_for_theme(endpoint: str, override_theme: str = None, **values):
     if endpoint == 'static' and values.get('filename'):
     if endpoint == 'static' and values.get('filename'):
         theme_name = get_current_theme_name(override=override_theme)
         theme_name = get_current_theme_name(override=override_theme)
         filename_with_theme = "themes/{}/{}".format(theme_name, values['filename'])
         filename_with_theme = "themes/{}/{}".format(theme_name, values['filename'])
@@ -354,7 +372,7 @@ def url_for_theme(endpoint, override_theme=None, **values):
     return url
     return url
 
 
 
 
-def proxify(url):
+def proxify(url: str):
     if url.startswith('//'):
     if url.startswith('//'):
         url = 'https:' + url
         url = 'https:' + url
 
 
@@ -369,7 +387,7 @@ def proxify(url):
     return '{0}?{1}'.format(settings['result_proxy']['url'], urlencode(url_params))
     return '{0}?{1}'.format(settings['result_proxy']['url'], urlencode(url_params))
 
 
 
 
-def image_proxify(url):
+def image_proxify(url: str):
 
 
     if url.startswith('//'):
     if url.startswith('//'):
         url = 'https:' + url
         url = 'https:' + url
@@ -405,7 +423,7 @@ def get_translations():
     }
     }
 
 
 
 
-def _get_enable_categories(all_categories):
+def _get_enable_categories(all_categories: Iterable[str]):
     disabled_engines = request.preferences.engines.get_disabled()
     disabled_engines = request.preferences.engines.get_disabled()
     enabled_categories = set(
     enabled_categories = set(
         # pylint: disable=consider-using-dict-items
         # pylint: disable=consider-using-dict-items
@@ -417,14 +435,14 @@ def _get_enable_categories(all_categories):
     return [x for x in all_categories if x in enabled_categories]
     return [x for x in all_categories if x in enabled_categories]
 
 
 
 
-def get_pretty_url(parsed_url):
+def get_pretty_url(parsed_url: urllib.parse.ParseResult):
     path = parsed_url.path
     path = parsed_url.path
     path = path[:-1] if len(path) > 0 and path[-1] == '/' else path
     path = path[:-1] if len(path) > 0 and path[-1] == '/' else path
     path = path.replace("/", " › ")
     path = path.replace("/", " › ")
     return [parsed_url.scheme + "://" + parsed_url.netloc, path]
     return [parsed_url.scheme + "://" + parsed_url.netloc, path]
 
 
 
 
-def render(template_name, override_theme=None, **kwargs):
+def render(template_name: str, override_theme: str = None, **kwargs):
     # values from the HTTP requests
     # values from the HTTP requests
     kwargs['endpoint'] = 'results' if 'q' in kwargs else request.endpoint
     kwargs['endpoint'] = 'results' if 'q' in kwargs else request.endpoint
     kwargs['cookies'] = request.cookies
     kwargs['cookies'] = request.cookies
@@ -552,7 +570,7 @@ def pre_request():
 
 
 
 
 @app.after_request
 @app.after_request
-def add_default_headers(response):
+def add_default_headers(response: flask.Response):
     # set default http headers
     # set default http headers
     for header, value in settings['server']['default_http_headers'].items():
     for header, value in settings['server']['default_http_headers'].items():
         if header in response.headers:
         if header in response.headers:
@@ -562,29 +580,28 @@ def add_default_headers(response):
 
 
 
 
 @app.after_request
 @app.after_request
-def post_request(response):
+def post_request(response: flask.Response):
     total_time = default_timer() - request.start_time
     total_time = default_timer() - request.start_time
     timings_all = [
     timings_all = [
         'total;dur=' + str(round(total_time * 1000, 3)),
         'total;dur=' + str(round(total_time * 1000, 3)),
         'render;dur=' + str(round(request.render_time * 1000, 3)),
         'render;dur=' + str(round(request.render_time * 1000, 3)),
     ]
     ]
     if len(request.timings) > 0:
     if len(request.timings) > 0:
-        timings = sorted(request.timings, key=lambda v: v['total'])
+        timings = sorted(request.timings, key=lambda t: t.total)
         timings_total = [
         timings_total = [
-            'total_' + str(i) + '_' + v['engine'] + ';dur=' + str(round(v['total'] * 1000, 3))
-            for i, v in enumerate(timings)
+            'total_' + str(i) + '_' + t.engine + ';dur=' + str(round(t.total * 1000, 3)) for i, t in enumerate(timings)
         ]
         ]
         timings_load = [
         timings_load = [
-            'load_' + str(i) + '_' + v['engine'] + ';dur=' + str(round(v['load'] * 1000, 3))
-            for i, v in enumerate(timings)
-            if v.get('load')
+            'load_' + str(i) + '_' + t.engine + ';dur=' + str(round(t.load * 1000, 3))
+            for i, t in enumerate(timings)
+            if t.load
         ]
         ]
         timings_all = timings_all + timings_total + timings_load
         timings_all = timings_all + timings_total + timings_load
     response.headers.add('Server-Timing', ', '.join(timings_all))
     response.headers.add('Server-Timing', ', '.join(timings_all))
     return response
     return response
 
 
 
 
-def index_error(output_format, error_message):
+def index_error(output_format: str, error_message: str):
     if output_format == 'json':
     if output_format == 'json':
         return Response(json.dumps({'error': error_message}), mimetype='application/json')
         return Response(json.dumps({'error': error_message}), mimetype='application/json')
     if output_format == 'csv':
     if output_format == 'csv':
@@ -828,23 +845,21 @@ def search():
     )
     )
 
 
 
 
-def __get_translated_errors(unresponsive_engines):
+def __get_translated_errors(unresponsive_engines: Iterable[UnresponsiveEngine]):
     translated_errors = []
     translated_errors = []
 
 
     # make a copy unresponsive_engines to avoid "RuntimeError: Set changed size
     # make a copy unresponsive_engines to avoid "RuntimeError: Set changed size
     # during iteration" it happens when an engine modifies the ResultContainer
     # during iteration" it happens when an engine modifies the ResultContainer
     # after the search_multiple_requests method has stopped waiting
     # after the search_multiple_requests method has stopped waiting
 
 
-    for unresponsive_engine in list(unresponsive_engines):
-        error_user_text = exception_classname_to_text.get(unresponsive_engine[1])
+    for unresponsive_engine in unresponsive_engines:
+        error_user_text = exception_classname_to_text.get(unresponsive_engine.error_type)
         if not error_user_text:
         if not error_user_text:
             error_user_text = exception_classname_to_text[None]
             error_user_text = exception_classname_to_text[None]
         error_msg = gettext(error_user_text)
         error_msg = gettext(error_user_text)
-        if unresponsive_engine[2]:
-            error_msg = "{} {}".format(error_msg, unresponsive_engine[2])
-        if unresponsive_engine[3]:
+        if unresponsive_engine.suspended:
             error_msg = gettext('Suspended') + ': ' + error_msg
             error_msg = gettext('Suspended') + ': ' + error_msg
-        translated_errors.append((unresponsive_engine[0], error_msg))
+        translated_errors.append((unresponsive_engine.engine, error_msg))
 
 
     return sorted(translated_errors, key=lambda e: e[0])
     return sorted(translated_errors, key=lambda e: e[0])
 
 
@@ -1060,7 +1075,7 @@ def preferences():
     )
     )
 
 
 
 
-def _is_selected_language_supported(engine, preferences):  # pylint: disable=redefined-outer-name
+def _is_selected_language_supported(engine, preferences: Preferences):  # pylint: disable=redefined-outer-name
     language = preferences.get_value('language')
     language = preferences.get_value('language')
     if language == 'all':
     if language == 'all':
         return True
         return True

+ 5 - 1
tests/unit/test_webapp.py

@@ -3,6 +3,7 @@
 import json
 import json
 from urllib.parse import ParseResult
 from urllib.parse import ParseResult
 from mock import Mock
 from mock import Mock
+from searx.results import Timing
 
 
 import searx.search.processors
 import searx.search.processors
 from searx.search import Search
 from searx.search import Search
@@ -46,7 +47,10 @@ class ViewsTestCase(SearxTestCase):
             },
             },
         ]
         ]
 
 
-        timings = [{'engine': 'startpage', 'total': 0.8, 'load': 0.7}, {'engine': 'youtube', 'total': 0.9, 'load': 0.6}]
+        timings = [
+            Timing(engine='startpage', total=0.8, load=0.7),
+            Timing(engine='youtube', total=0.9, load=0.6),
+        ]
 
 
         def search_mock(search_self, *args):
         def search_mock(search_self, *args):
             search_self.result_container = Mock(
             search_self.result_container = Mock(