1"""The `suricata_check.suricata_check` module contains the command line utility and the main program logic."""
2
3import logging
4import logging.handlers
5import pkgutil
6from collections.abc import Sequence
7from functools import lru_cache
8
9from suricata_check.checkers.interface import CheckerInterface
10from suricata_check.checkers.interface.dummy import DummyChecker
11from suricata_check.utils.checker_typing import (
12 get_all_subclasses,
13)
14from suricata_check.utils.regex import get_regex_provider
15
_logger = logging.getLogger(__name__)

_regex_provider = get_regex_provider()

# Flipped to True by `_import_extensions()` after its first run, so repeated
# calls to `get_checkers()` do not scan for and import extensions again.
suricata_check_extensions_imported: bool = False
22
23
def _import_extensions() -> None:
    """Import installed suricata-check extensions, at most once per process.

    Every top-level module reported by :func:`pkgutil.iter_modules` whose name
    starts with ``suricata_check_`` is treated as a potential extension and
    imported (presumably so that any checkers it defines are registered before
    discovery runs). Import failures are logged as warnings rather than raised,
    and a missing ``__version__`` attribute is tolerated.

    The module-level flag ``suricata_check_extensions_imported`` is set once the
    scan has completed, turning subsequent calls into no-ops.
    """
    global suricata_check_extensions_imported  # noqa: PLW0603
    if suricata_check_extensions_imported:
        return

    for module in pkgutil.iter_modules():
        if not module.name.startswith("suricata_check_"):
            continue

        # Keep the try body to the import itself so only import failures are
        # treated as "failed to import".
        try:
            imported_module = __import__(module.name)
        except ImportError:
            _logger.warning(
                "Detected potential suricata-check extension %s but failed to import it.",
                module.name.replace("_", "-"),
            )
        else:
            # Extensions are not required to define __version__; fall back to a
            # placeholder instead of letting AttributeError abort discovery.
            _logger.info(
                "Detected and successfully imported suricata-check extension %s with version %s.",
                module.name.replace("_", "-"),
                getattr(imported_module, "__version__", "unknown"),
            )

    suricata_check_extensions_imported = True
44
45
@lru_cache(maxsize=1)
def get_checkers(
    include: Sequence[str] = (".*",),
    exclude: Sequence[str] = (),
    issue_severity: int = logging.INFO,
) -> Sequence[CheckerInterface]:
    """Auto discovers all available checkers that implement the CheckerInterface.

    Potential extensions (top-level modules whose name starts with
    ``suricata_check_``) are imported first, then every subclass of
    ``CheckerInterface`` is considered and instantiated if it is enabled.

    The result is memoized with ``functools.lru_cache(maxsize=1)``. All
    arguments must therefore be hashable (pass tuples, not lists). A call with
    the same arguments as the previous call returns the same cached list of
    checker instances.

    Args:
        include: Regexes (implicitly anchored with ``^`` and ``$``) selecting
            issue codes. An empty sequence selects all codes. A non-empty
            sequence that selects at least one code of a checker enables it.
        exclude: Regexes (implicitly anchored) whose matching codes are removed
            from the selection.
        issue_severity: Minimum severity an issue code must have to be kept.

    Returns:
        A list of available checkers that implement the CheckerInterface,
        sorted by class name.

    """
    # Check for extensions and try to import them
    _import_extensions()

    checkers: list[CheckerInterface] = []
    for checker in get_all_subclasses(CheckerInterface):
        # Skip the DummyChecker base class itself (matched by class name).
        if checker.__name__ == DummyChecker.__name__:
            continue

        # Initialize DummyCheckers to retrieve error messages.
        if issubclass(checker, DummyChecker):
            checker()

        enabled, relevant_codes = __get_checker_enabled(
            checker, include, exclude, issue_severity
        )

        if enabled:
            checkers.append(checker(include=relevant_codes))
        else:
            _logger.info("Checker %s is disabled.", checker.__name__)

    _logger.info(
        "Discovered and enabled checkers: [%s]",
        ", ".join([c.__class__.__name__ for c in checkers]),
    )
    if not checkers:
        _logger.warning(
            "No checkers were enabled. Check the include and exclude arguments."
        )

    # Perform a uniqueness check on the codes emitted by the checkers.
    # Each unordered pair is compared once, so an overlap is reported a single
    # time instead of twice (A/B and B/A).
    for index, checker1 in enumerate(checkers):
        for checker2 in checkers[index + 1 :]:
            if not set(checker1.codes).isdisjoint(checker2.codes):
                msg = f"Checker {checker1.__class__.__name__} and {checker2.__class__.__name__} have overlapping codes."
                _logger.error(msg)

    return sorted(checkers, key=lambda x: x.__class__.__name__)
98 return sorted(checkers, key=lambda x: x.__class__.__name__)
99
100
101def __get_checker_enabled(
102 checker: type[CheckerInterface],
103 include: Sequence[str],
104 exclude: Sequence[str],
105 issue_severity: int,
106) -> tuple[bool, set[str]]:
107 enabled = checker.enabled_by_default
108
109 # If no include regexes are provided, include all by default
110 if len(include) == 0:
111 relevant_codes = set(checker.codes.keys())
112 else:
113 # If include regexes are provided, include all codes that match any of these regexes
114 relevant_codes = set()
115
116 for regex in include:
117 relevant_codes.update(
118 set(
119 filter(
120 lambda code: _regex_provider.compile("^" + regex + "$").match(
121 code
122 )
123 is not None,
124 checker.codes.keys(),
125 )
126 )
127 )
128
129 if len(relevant_codes) > 0:
130 enabled = True
131
132 # Now remove the codes that are excluded according to any of the provided exclude regexes
133 for regex in exclude:
134 relevant_codes = set(
135 filter(
136 lambda code: _regex_provider.compile("^" + regex + "$").match(code)
137 is None,
138 relevant_codes,
139 )
140 )
141
142 # Now filter out irrelevant codes based on severity
143 relevant_codes = set(
144 filter(
145 lambda code: checker.codes[code]["severity"] >= issue_severity,
146 relevant_codes,
147 )
148 )
149
150 if len(relevant_codes) == 0:
151 enabled = False
152
153 return enabled, relevant_codes