Browse Source

[mod] searx.utils: more typing

Alexandre Flament 3 years ago
parent
commit
2d5929cc59
1 changed files with 22 additions and 12 deletions
  1. 22 12
      searx/utils.py

+ 22 - 12
searx/utils.py

@@ -6,6 +6,8 @@
 """
 """
 import re
 import re
 import importlib
 import importlib
+import importlib.util
+import types
 
 
 from typing import Optional, Union, Any, Set, List, Dict, MutableMapping, Tuple, Callable
 from typing import Optional, Union, Any, Set, List, Dict, MutableMapping, Tuple, Callable
 from numbers import Number
 from numbers import Number
@@ -45,8 +47,8 @@ _STORAGE_UNIT_VALUE: Dict[str, int] = {
     'KiB': 1000,
     'KiB': 1000,
 }
 }
 
 
-_XPATH_CACHE = {}
-_LANG_TO_LC_CACHE = {}
+_XPATH_CACHE: Dict[str, XPath] = {}
+_LANG_TO_LC_CACHE: Dict[str, Dict[str, str]] = {}
 
 
 
 
 class _NotSetClass:  # pylint: disable=too-few-public-methods
 class _NotSetClass:  # pylint: disable=too-few-public-methods
@@ -150,7 +152,7 @@ def html_to_text(html_str: str) -> str:
     return s.get_text()
     return s.get_text()
 
 
 
 
-def extract_text(xpath_results, allow_none: bool = False):
+def extract_text(xpath_results, allow_none: bool = False) -> Optional[str]:
     """Extract text from a lxml result
     """Extract text from a lxml result
 
 
     * if xpath_results is list, extract the text from each result and concat the list
     * if xpath_results is list, extract the text from each result and concat the list
@@ -264,7 +266,9 @@ def extract_url(xpath_results, base_url) -> str:
         raise ValueError('Empty url resultset')
         raise ValueError('Empty url resultset')
 
 
     url = extract_text(xpath_results)
     url = extract_text(xpath_results)
-    return normalize_url(url, base_url)
+    if url:
+        return normalize_url(url, base_url)
+    raise ValueError('URL not found')
 
 
 
 
 def dict_subset(dictionnary: MutableMapping, properties: Set[str]) -> Dict:
 def dict_subset(dictionnary: MutableMapping, properties: Set[str]) -> Dict:
@@ -366,7 +370,7 @@ def _get_lang_to_lc_dict(lang_list: List[str]) -> Dict[str, str]:
 
 
 # babel's get_global contains all sorts of miscellaneous locale and territory related data
 # babel's get_global contains all sorts of miscellaneous locale and territory related data
 # see get_global in: https://github.com/python-babel/babel/blob/master/babel/core.py
 # see get_global in: https://github.com/python-babel/babel/blob/master/babel/core.py
-def _get_from_babel(lang_code, key: str):
+def _get_from_babel(lang_code: str, key: str):
     match = get_global(key).get(lang_code.replace('-', '_'))
     match = get_global(key).get(lang_code.replace('-', '_'))
     # for some keys, such as territory_aliases, match may be a list
     # for some keys, such as territory_aliases, match may be a list
     if isinstance(match, str):
     if isinstance(match, str):
@@ -374,7 +378,7 @@ def _get_from_babel(lang_code, key: str):
     return match
     return match
 
 
 
 
-def _match_language(lang_code, lang_list=[], custom_aliases={}) -> Optional[str]:  # pylint: disable=W0102
+def _match_language(lang_code: str, lang_list=[], custom_aliases={}) -> Optional[str]:  # pylint: disable=W0102
     """auxiliary function to match lang_code in lang_list"""
     """auxiliary function to match lang_code in lang_list"""
     # replace language code with a custom alias if necessary
     # replace language code with a custom alias if necessary
     if lang_code in custom_aliases:
     if lang_code in custom_aliases:
@@ -396,10 +400,12 @@ def _match_language(lang_code, lang_list=[], custom_aliases={}) -> Optional[str]
             return new_code
             return new_code
 
 
     # try to get the any supported country for this language
     # try to get the any supported country for this language
-    return _get_lang_to_lc_dict(lang_list).get(lang_code, None)
+    return _get_lang_to_lc_dict(lang_list).get(lang_code)
 
 
 
 
-def match_language(locale_code, lang_list=[], custom_aliases={}, fallback='en-US') -> str:  # pylint: disable=W0102
+def match_language(  # pylint: disable=W0102
+    locale_code, lang_list=[], custom_aliases={}, fallback: Optional[str] = 'en-US'
+) -> Optional[str]:
     """get the language code from lang_list that best matches locale_code"""
     """get the language code from lang_list that best matches locale_code"""
     # try to get language from given locale_code
     # try to get language from given locale_code
     language = _match_language(locale_code, lang_list, custom_aliases)
     language = _match_language(locale_code, lang_list, custom_aliases)
@@ -437,12 +443,16 @@ def match_language(locale_code, lang_list=[], custom_aliases={}, fallback='en-US
     return language or fallback
     return language or fallback
 
 
 
 
-def load_module(filename: str, module_dir: str):
+def load_module(filename: str, module_dir: str) -> types.ModuleType:
     modname = splitext(filename)[0]
     modname = splitext(filename)[0]
-    filepath = join(module_dir, filename)
+    modpath = join(module_dir, filename)
     # and https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
     # and https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
-    spec = importlib.util.spec_from_file_location(modname, filepath)
+    spec = importlib.util.spec_from_file_location(modname, modpath)
+    if not spec:
+        raise ValueError(f"Error loading '{modpath}' module")
     module = importlib.util.module_from_spec(spec)
     module = importlib.util.module_from_spec(spec)
+    if not spec.loader:
+        raise ValueError(f"Error loading '{modpath}' module")
     spec.loader.exec_module(module)
     spec.loader.exec_module(module)
     return module
     return module
 
 
@@ -477,7 +487,7 @@ def ecma_unescape(string: str) -> str:
     return string
     return string
 
 
 
 
-def get_string_replaces_function(replaces: Dict[str, str]) -> Callable:
+def get_string_replaces_function(replaces: Dict[str, str]) -> Callable[[str], str]:
     rep = {re.escape(k): v for k, v in replaces.items()}
     rep = {re.escape(k): v for k, v in replaces.items()}
     pattern = re.compile("|".join(rep.keys()))
     pattern = re.compile("|".join(rep.keys()))