Browse Source

plugins: refactor initialization

Add a new per-plugin function "init" that is called when the app starts.
The function can:
* return False to disable the plugin.
* modify the Flask app.
Alexandre Flament 3 years ago
parent
commit
2b4fef7118
5 changed files with 250 additions and 190 deletions
  1. 204 156
      searx/plugins/__init__.py
  2. 10 8
      searx/plugins/ahmia_filter.py
  3. 2 2
      searx/settings_defaults.py
  4. 22 20
      searx/webapp.py
  5. 12 4
      tests/unit/test_plugins.py

+ 204 - 156
searx/plugins/__init__.py

@@ -1,122 +1,38 @@
-'''
-searx is free software: you can redistribute it and/or modify
-it under the terms of the GNU Affero General Public License as published by
-the Free Software Foundation, either version 3 of the License, or
-(at your option) any later version.
-
-searx is distributed in the hope that it will be useful,
-but WITHOUT ANY WARRANTY; without even the implied warranty of
-MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-GNU Affero General Public License for more details.
-
-You should have received a copy of the GNU Affero General Public License
-along with searx. If not, see < http://www.gnu.org/licenses/ >.
-
-(C) 2015 by Adam Tauber, <asciimoo@gmail.com>
-'''
+# SPDX-License-Identifier: AGPL-3.0-or-later
+# lint: pylint
+# pylint: disable=missing-module-docstring, missing-class-docstring
 
 
+import sys
 from hashlib import sha256
 from hashlib import sha256
 from importlib import import_module
 from importlib import import_module
 from os import listdir, makedirs, remove, stat, utime
 from os import listdir, makedirs, remove, stat, utime
 from os.path import abspath, basename, dirname, exists, join
 from os.path import abspath, basename, dirname, exists, join
 from shutil import copyfile
 from shutil import copyfile
+from pkgutil import iter_modules
+from logging import getLogger
 
 
 from searx import logger, settings
 from searx import logger, settings
 
 
 
 
-logger = logger.getChild('plugins')
-
-from searx.plugins import (oa_doi_rewrite,
-                           ahmia_filter,
-                           hash_plugin,
-                           infinite_scroll,
-                           self_info,
-                           hostname_replace,
-                           search_on_category_select,
-                           tracker_url_remover,
-                           vim_hotkeys)
-
-required_attrs = (('name', str),
-                  ('description', str),
-                  ('default_on', bool))
-
-optional_attrs = (('js_dependencies', tuple),
-                  ('css_dependencies', tuple),
-                  ('preference_section', str))
-
-
-class Plugin():
-    default_on = False
-    name = 'Default plugin'
-    description = 'Default plugin description'
-
-
-class PluginStore():
-
-    def __init__(self):
-        self.plugins = []
-
-    def __iter__(self):
-        for plugin in self.plugins:
-            yield plugin
-
-    def register(self, *plugins, external=False):
-        if external:
-            plugins = load_external_plugins(plugins)
-        for plugin in plugins:
-            for plugin_attr, plugin_attr_type in required_attrs:
-                if not hasattr(plugin, plugin_attr):
-                    logger.critical('missing attribute "{0}", cannot load plugin: {1}'.format(plugin_attr, plugin))
-                    exit(3)
-                attr = getattr(plugin, plugin_attr)
-                if not isinstance(attr, plugin_attr_type):
-                    type_attr = str(type(attr))
-                    logger.critical(
-                        'attribute "{0}" is of type {2}, must be of type {3}, cannot load plugin: {1}'
-                        .format(plugin_attr, plugin, type_attr, plugin_attr_type)
-                    )
-                    exit(3)
-            for plugin_attr, plugin_attr_type in optional_attrs:
-                if not hasattr(plugin, plugin_attr) or not isinstance(getattr(plugin, plugin_attr), plugin_attr_type):
-                    setattr(plugin, plugin_attr, plugin_attr_type())
-            plugin.id = plugin.name.replace(' ', '_')
-            if not hasattr(plugin, 'preference_section'):
-                plugin.preference_section = 'general'
-            if plugin.preference_section == 'query':
-                for plugin_attr in ('query_keywords', 'query_examples'):
-                    if not hasattr(plugin, plugin_attr):
-                        logger.critical('missing attribute "{0}", cannot load plugin: {1}'.format(plugin_attr, plugin))
-                        exit(3)
-            self.plugins.append(plugin)
-
-    def call(self, ordered_plugin_list, plugin_type, request, *args, **kwargs):
-        ret = True
-        for plugin in ordered_plugin_list:
-            if hasattr(plugin, plugin_type):
-                ret = getattr(plugin, plugin_type)(request, *args, **kwargs)
-                if not ret:
-                    break
-
-        return ret
-
+logger = logger.getChild("plugins")
 
 
-def load_external_plugins(plugin_names):
-    plugins = []
-    for name in plugin_names:
-        logger.debug('loading plugin: {0}'.format(name))
-        try:
-            pkg = import_module(name)
-        except Exception as e:
-            logger.critical('failed to load plugin module {0}: {1}'.format(name, e))
-            exit(3)
+required_attrs = (
+    ("name", str),
+    ("description", str),
+    ("default_on", bool)
+)
 
 
-        pkg.__base_path = dirname(abspath(pkg.__file__))
+optional_attrs = (
+    ("js_dependencies", tuple),
+    ("css_dependencies", tuple),
+    ("preference_section", str),
+)
 
 
-        prepare_package_resources(pkg, name)
 
 
-        plugins.append(pkg)
-        logger.debug('plugin "{0}" loaded'.format(name))
-    return plugins
+def sha_sum(filename):
+    with open(filename, "rb") as f:
+        file_content_bytes = f.read()
+        return sha256(file_content_bytes).hexdigest()
 
 
 
 
 def sync_resource(base_path, resource_path, name, target_dir, plugin_dir):
 def sync_resource(base_path, resource_path, name, target_dir, plugin_dir):
@@ -130,74 +46,206 @@ def sync_resource(base_path, resource_path, name, target_dir, plugin_dir):
             # the HTTP server) do not change
             # the HTTP server) do not change
             dep_stat = stat(dep_path)
             dep_stat = stat(dep_path)
             utime(resource_path, ns=(dep_stat.st_atime_ns, dep_stat.st_mtime_ns))
             utime(resource_path, ns=(dep_stat.st_atime_ns, dep_stat.st_mtime_ns))
-        except:
-            logger.critical('failed to copy plugin resource {0} for plugin {1}'.format(file_name, name))
-            exit(3)
+        except IOError:
+            logger.critical(
+                "failed to copy plugin resource {0} for plugin {1}".format(
+                    file_name, name
+                )
+            )
+            sys.exit(3)
 
 
     # returning with the web path of the resource
     # returning with the web path of the resource
-    return join('plugins/external_plugins', plugin_dir, file_name)
+    return join("plugins/external_plugins", plugin_dir, file_name)
 
 
 
 
-def prepare_package_resources(pkg, name):
-    plugin_dir = 'plugin_' + name
-    target_dir = join(settings['ui']['static_path'], 'plugins/external_plugins', plugin_dir)
+def prepare_package_resources(plugin, plugin_module_name):
+    # pylint: disable=consider-using-generator
+    plugin_base_path = dirname(abspath(plugin.__file__))
+
+    plugin_dir = plugin_module_name
+    target_dir = join(
+        settings["ui"]["static_path"], "plugins/external_plugins", plugin_dir
+    )
     try:
     try:
         makedirs(target_dir, exist_ok=True)
         makedirs(target_dir, exist_ok=True)
-    except:
-        logger.critical('failed to create resource directory {0} for plugin {1}'.format(target_dir, name))
-        exit(3)
+    except IOError:
+        logger.critical(
+            "failed to create resource directory {0} for plugin {1}".format(
+                target_dir, plugin_module_name
+            )
+        )
+        sys.exit(3)
 
 
     resources = []
     resources = []
 
 
-    if hasattr(pkg, 'js_dependencies'):
-        resources.extend(map(basename, pkg.js_dependencies))
-        pkg.js_dependencies = tuple([
-            sync_resource(pkg.__base_path, x, name, target_dir, plugin_dir)
-            for x in pkg.js_dependencies
-        ])
-    if hasattr(pkg, 'css_dependencies'):
-        resources.extend(map(basename, pkg.css_dependencies))
-        pkg.css_dependencies = tuple([
-            sync_resource(pkg.__base_path, x, name, target_dir, plugin_dir)
-            for x in pkg.css_dependencies
-        ])
+    if hasattr(plugin, "js_dependencies"):
+        resources.extend(map(basename, plugin.js_dependencies))
+        plugin.js_dependencies = tuple(
+            [
+                sync_resource(
+                    plugin_base_path, x, plugin_module_name, target_dir, plugin_dir
+                )
+                for x in plugin.js_dependencies
+            ]
+        )
+    if hasattr(plugin, "css_dependencies"):
+        resources.extend(map(basename, plugin.css_dependencies))
+        plugin.css_dependencies = tuple(
+            [
+                sync_resource(
+                    plugin_base_path, x, plugin_module_name, target_dir, plugin_dir
+                )
+                for x in plugin.css_dependencies
+            ]
+        )
 
 
     for f in listdir(target_dir):
     for f in listdir(target_dir):
         if basename(f) not in resources:
         if basename(f) not in resources:
             resource_path = join(target_dir, basename(f))
             resource_path = join(target_dir, basename(f))
             try:
             try:
                 remove(resource_path)
                 remove(resource_path)
-            except:
-                logger.critical('failed to remove unused resource file {0} for plugin {1}'.format(resource_path, name))
-                exit(3)
+            except IOError:
+                logger.critical(
+                    "failed to remove unused resource file {0} for plugin {1}".format(
+                        resource_path, plugin_module_name
+                    )
+                )
+                sys.exit(3)
 
 
 
 
-def sha_sum(filename):
-    with open(filename, "rb") as f:
-        file_content_bytes = f.read()
-        return sha256(file_content_bytes).hexdigest()
+def load_plugin(plugin_module_name, external):
+    # pylint: disable=too-many-branches
+    try:
+        plugin = import_module(plugin_module_name)
+    except (
+        SyntaxError,
+        KeyboardInterrupt,
+        SystemExit,
+        SystemError,
+        ImportError,
+        RuntimeError,
+    ) as e:
+        logger.critical("%s: fatal exception", plugin_module_name, exc_info=e)
+        sys.exit(3)
+    except BaseException:
+        logger.exception("%s: exception while loading, the plugin is disabled", plugin_module_name)
+        return None
+
+    # difference with searx: the plugin id is the module name instead of the user-facing plugin name
+    plugin.id = plugin_module_name
+
+    # give each plugin its own logger, named after its module
+    plugin.logger = getLogger(plugin_module_name)
+
+    for plugin_attr, plugin_attr_type in required_attrs:
+        if not hasattr(plugin, plugin_attr):
+            logger.critical(
+                '%s: missing attribute "%s", cannot load plugin', plugin, plugin_attr
+            )
+            sys.exit(3)
+        attr = getattr(plugin, plugin_attr)
+        if not isinstance(attr, plugin_attr_type):
+            type_attr = str(type(attr))
+            logger.critical(
+                '{1}: attribute "{0}" is of type {2}, must be of type {3}, cannot load plugin'.format(
+                    plugin, plugin_attr, type_attr, plugin_attr_type
+                )
+            )
+            sys.exit(3)
+
+    for plugin_attr, plugin_attr_type in optional_attrs:
+        if not hasattr(plugin, plugin_attr) or not isinstance(
+            getattr(plugin, plugin_attr), plugin_attr_type
+        ):
+            setattr(plugin, plugin_attr, plugin_attr_type())
+
+    if not hasattr(plugin, "preference_section"):
+        plugin.preference_section = "general"
+
+    # query plugin
+    if plugin.preference_section == "query":
+        for plugin_attr in ("query_keywords", "query_examples"):
+            if not hasattr(plugin, plugin_attr):
+                logger.critical(
+                    'missing attribute "{0}", cannot load plugin: {1}'.format(
+                        plugin_attr, plugin
+                    )
+                )
+                sys.exit(3)
+
+    if settings.get("enabled_plugins"):
+        # searx compatibility: plugin.name in settings['enabled_plugins']
+        plugin.default_on = (
+            plugin.name in settings["enabled_plugins"]
+            or plugin.id in settings["enabled_plugins"]
+        )
+
+    # copy resources if this is an external plugin
+    if external:
+        prepare_package_resources(plugin, plugin_module_name)
+
+    logger.debug("%s: loaded", plugin_module_name)
+
+    return plugin
+
+
+def load_and_initialize_plugin(plugin_module_name, external, init_args):
+    plugin = load_plugin(plugin_module_name, external)
+    if plugin and hasattr(plugin, 'init'):
+        try:
+            return plugin if plugin.init(*init_args) else None
+        except Exception:  # pylint: disable=broad-except
+            plugin.logger.exception(
+                "Exception while calling init, the plugin is disabled"
+            )
+            return None
+    return plugin
+
+
+class PluginStore:
+    def __init__(self):
+        self.plugins = []
+
+    def __iter__(self):
+        for plugin in self.plugins:
+            yield plugin
+
+    def register(self, plugin):
+        self.plugins.append(plugin)
+
+    def call(self, ordered_plugin_list, plugin_type, *args, **kwargs):
+        # pylint: disable=no-self-use
+        ret = True
+        for plugin in ordered_plugin_list:
+            if hasattr(plugin, plugin_type):
+                try:
+                    ret = getattr(plugin, plugin_type)(*args, **kwargs)
+                    if not ret:
+                        break
+                except Exception:  # pylint: disable=broad-except
+                    plugin.logger.exception("Exception while calling %s", plugin_type)
+        return ret
 
 
 
 
 plugins = PluginStore()
 plugins = PluginStore()
-plugins.register(oa_doi_rewrite)
-plugins.register(hash_plugin)
-plugins.register(infinite_scroll)
-plugins.register(self_info)
-plugins.register(hostname_replace)
-plugins.register(search_on_category_select)
-plugins.register(tracker_url_remover)
-plugins.register(vim_hotkeys)
-# load external plugins
-if settings['plugins']:
-    plugins.register(*settings['plugins'], external=True)
-
-if settings['enabled_plugins']:
-    for plugin in plugins:
-        if plugin.name in settings['enabled_plugins']:
-            plugin.default_on = True
-        else:
-            plugin.default_on = False
-
-# load tor specific plugins
-if settings['outgoing']['using_tor_proxy']:
-    plugins.register(ahmia_filter)
+
+
+def plugin_module_names():
+    yield_plugins = set()
+
+    # embedded plugins
+    for module_name in iter_modules(path=[dirname(__file__)]):
+        yield (module_name, False)
+        yield_plugins.add(module_name)
+    # external plugins
+    for module_name in settings['plugins']:
+        if module_name not in yield_plugins:
+            yield (module_name, True)
+            yield_plugins.add(module_name)
+
+
+def initialize(app):
+    for module_name, external in plugin_module_names():
+        plugin = load_and_initialize_plugin(__name__ + "." + module_name.name, external, (app, settings))
+        if plugin:
+            plugins.register(plugin)

+ 10 - 8
searx/plugins/ahmia_filter.py

@@ -13,15 +13,17 @@ preference_section = 'onions'
 ahmia_blacklist = None
 ahmia_blacklist = None
 
 
 
 
-def get_ahmia_blacklist():
-    global ahmia_blacklist
-    if not ahmia_blacklist:
-        ahmia_blacklist = ahmia_blacklist_loader()
-    return ahmia_blacklist
-
-
 def on_result(request, search, result):
 def on_result(request, search, result):
     if not result.get('is_onion') or not result.get('parsed_url'):
     if not result.get('is_onion') or not result.get('parsed_url'):
         return True
         return True
     result_hash = md5(result['parsed_url'].hostname.encode()).hexdigest()
     result_hash = md5(result['parsed_url'].hostname.encode()).hexdigest()
-    return result_hash not in get_ahmia_blacklist()
+    return result_hash not in ahmia_blacklist
+
+
+def init(app, settings):
+    global ahmia_blacklist  # pylint: disable=global-statement
+    if not settings['outgoing']['using_tor_proxy']:
+        # disable the plugin
+        return False
+    ahmia_blacklist = ahmia_blacklist_loader()
+    return True

+ 2 - 2
searx/settings_defaults.py

@@ -200,8 +200,8 @@ SCHEMA = {
         'networks': {
         'networks': {
         },
         },
     },
     },
-    'plugins': SettingsValue((None, list), None),
-    'enabled_plugins': SettingsValue(list, []),
+    'plugins': SettingsValue(list, []),
+    'enabled_plugins': SettingsValue((None, list), None),
     'checker': {
     'checker': {
         'off_when_debug': SettingsValue(bool, True),
         'off_when_debug': SettingsValue(bool, True),
     },
     },

+ 22 - 20
searx/webapp.py

@@ -85,7 +85,7 @@ from searx.utils import (
 )
 )
 from searx.version import VERSION_STRING, GIT_URL, GIT_BRANCH
 from searx.version import VERSION_STRING, GIT_URL, GIT_BRANCH
 from searx.query import RawTextQuery
 from searx.query import RawTextQuery
-from searx.plugins import plugins
+from searx.plugins import plugins, initialize as plugin_initialize
 from searx.plugins.oa_doi_rewrite import get_doi_resolver
 from searx.plugins.oa_doi_rewrite import get_doi_resolver
 from searx.preferences import (
 from searx.preferences import (
     Preferences,
     Preferences,
@@ -158,25 +158,6 @@ app.jinja_env.lstrip_blocks = True
 app.jinja_env.add_extension('jinja2.ext.loopcontrols')  # pylint: disable=no-member
 app.jinja_env.add_extension('jinja2.ext.loopcontrols')  # pylint: disable=no-member
 app.secret_key = settings['server']['secret_key']
 app.secret_key = settings['server']['secret_key']
 
 
-# see https://flask.palletsprojects.com/en/1.1.x/cli/
-# True if "FLASK_APP=searx/webapp.py FLASK_ENV=development flask run"
-flask_run_development = (
-    os.environ.get("FLASK_APP") is not None
-    and os.environ.get("FLASK_ENV") == 'development'
-    and is_flask_run_cmdline()
-)
-
-# True if reload feature is activated of werkzeug, False otherwise (including uwsgi, etc..)
-#  __name__ != "__main__" if searx.webapp is imported (make test, make docs, uwsgi...)
-# see run() at the end of this file : searx_debug activates the reload feature.
-werkzeug_reloader = flask_run_development or (searx_debug and __name__ == "__main__")
-
-# initialize the engines except on the first run of the werkzeug server.
-if (not werkzeug_reloader
-    or (werkzeug_reloader
-        and os.environ.get("WERKZEUG_RUN_MAIN") == "true") ):
-    search_initialize(enable_checker=True)
-
 babel = Babel(app)
 babel = Babel(app)
 
 
 # used when translating category names
 # used when translating category names
@@ -1351,6 +1332,27 @@ def page_not_found(_e):
     return render('404.html'), 404
     return render('404.html'), 404
 
 
 
 
+# see https://flask.palletsprojects.com/en/1.1.x/cli/
+# True if "FLASK_APP=searx/webapp.py FLASK_ENV=development flask run"
+flask_run_development = (
+    os.environ.get("FLASK_APP") is not None
+    and os.environ.get("FLASK_ENV") == 'development'
+    and is_flask_run_cmdline()
+)
+
+# True if the reload feature of werkzeug is activated, False otherwise (including uwsgi, etc.)
+#  __name__ != "__main__" if searx.webapp is imported (make test, make docs, uwsgi...)
+# see run() at the end of this file : searx_debug activates the reload feature.
+werkzeug_reloader = flask_run_development or (searx_debug and __name__ == "__main__")
+
+# initialize the plugins and the engines except on the first run of the werkzeug server.
+if (not werkzeug_reloader
+    or (werkzeug_reloader
+        and os.environ.get("WERKZEUG_RUN_MAIN") == "true") ):
+    plugin_initialize(app)
+    search_initialize(enable_checker=True)
+
+
 def run():
 def run():
     logger.debug(
     logger.debug(
         'starting webserver on %s:%s',
         'starting webserver on %s:%s',

+ 12 - 4
tests/unit/test_plugins.py

@@ -10,6 +10,12 @@ def get_search_mock(query, **kwargs):
                 result_container=Mock(answers=dict()))
                 result_container=Mock(answers=dict()))
 
 
 
 
+class PluginMock():
+    default_on = False
+    name = 'Default plugin'
+    description = 'Default plugin description'
+
+
 class PluginStoreTest(SearxTestCase):
 class PluginStoreTest(SearxTestCase):
 
 
     def test_PluginStore_init(self):
     def test_PluginStore_init(self):
@@ -18,14 +24,14 @@ class PluginStoreTest(SearxTestCase):
 
 
     def test_PluginStore_register(self):
     def test_PluginStore_register(self):
         store = plugins.PluginStore()
         store = plugins.PluginStore()
-        testplugin = plugins.Plugin()
+        testplugin = PluginMock()
         store.register(testplugin)
         store.register(testplugin)
 
 
         self.assertTrue(len(store.plugins) == 1)
         self.assertTrue(len(store.plugins) == 1)
 
 
     def test_PluginStore_call(self):
     def test_PluginStore_call(self):
         store = plugins.PluginStore()
         store = plugins.PluginStore()
-        testplugin = plugins.Plugin()
+        testplugin = PluginMock()
         store.register(testplugin)
         store.register(testplugin)
         setattr(testplugin, 'asdf', Mock())
         setattr(testplugin, 'asdf', Mock())
         request = Mock()
         request = Mock()
@@ -40,8 +46,9 @@ class PluginStoreTest(SearxTestCase):
 class SelfIPTest(SearxTestCase):
 class SelfIPTest(SearxTestCase):
 
 
     def test_PluginStore_init(self):
     def test_PluginStore_init(self):
+        plugin = plugins.load_and_initialize_plugin('searx.plugins.self_info', False, (None, {}))
         store = plugins.PluginStore()
         store = plugins.PluginStore()
-        store.register(plugins.self_info)
+        store.register(plugin)
 
 
         self.assertTrue(len(store.plugins) == 1)
         self.assertTrue(len(store.plugins) == 1)
 
 
@@ -89,7 +96,8 @@ class HashPluginTest(SearxTestCase):
 
 
     def test_PluginStore_init(self):
     def test_PluginStore_init(self):
         store = plugins.PluginStore()
         store = plugins.PluginStore()
-        store.register(plugins.hash_plugin)
+        plugin = plugins.load_and_initialize_plugin('searx.plugins.hash_plugin', False, (None, {}))
+        store.register(plugin)
 
 
         self.assertTrue(len(store.plugins) == 1)
         self.assertTrue(len(store.plugins) == 1)