1"""The `suricata_check._checkers` module contains functionality for selecting checkers."""
2
3import logging
4import logging.handlers
5import pkgutil
6from collections.abc import Iterable, Sequence
7from functools import lru_cache
8from typing import TypeVar
9
10from suricata_check.checkers.interface import CheckerInterface
11from suricata_check.checkers.interface._dummy import DummyChecker
12from suricata_check.utils.regex_provider import get_regex_provider
13
# Logger for this module.
_logger = logging.getLogger(name=__name__)

# Regex provider used to compile the include/exclude patterns for issue codes.
_regex_provider = get_regex_provider()

# Set by _import_extensions() after its first scan, so that repeated calls to
# get_checkers() do not re-scan and re-import suricata-check extensions.
suricata_check_extensions_imported: bool = False
20
21
def _import_extensions() -> None:
    """Import all installed suricata-check extensions, at most once.

    Every top-level module on the import path whose name starts with
    ``suricata_check_`` is treated as an extension and imported. Failures to
    import are logged as warnings and otherwise ignored, so a broken extension
    does not prevent the remaining extensions (or the built-in checkers) from
    being used. Once the scan completes, the module-level flag
    ``suricata_check_extensions_imported`` is set and later calls return
    immediately.
    """
    global suricata_check_extensions_imported  # noqa: PLW0603
    if suricata_check_extensions_imported:
        return

    for module in pkgutil.iter_modules():
        if not module.name.startswith("suricata_check_"):
            continue

        try:
            imported_module = __import__(module.name)
        except ImportError:
            _logger.warning(
                "Detected potential suricata-check extension %s but failed to import it.",
                module.name.replace("_", "-"),
            )
            continue

        # Extensions are not required to define `__version__`; a missing attribute
        # must not raise AttributeError and abort discovery of other extensions.
        _logger.info(
            "Detected and successfully imported suricata-check extension %s with version %s.",
            module.name.replace("_", "-"),
            getattr(imported_module, "__version__", "unknown"),
        )

    suricata_check_extensions_imported = True
42
43
[docs]
def get_checkers(
    include: Sequence[str] = (".*",),
    exclude: Sequence[str] = (),
    issue_severity: int = logging.INFO,
) -> Sequence[CheckerInterface]:
    """Auto discovers all available checkers that implement the CheckerInterface.

    The result for the most recently used combination of arguments is cached, so
    repeated calls with equal arguments return the same checker instances.

    Args:
        include: Regular expressions selecting which issue codes to include.
            An empty sequence includes all codes of every checker.
        exclude: Regular expressions selecting issue codes to remove again.
        issue_severity: Minimum severity an issue code must have to be kept.

    Returns:
        A list of available checkers that implement the CheckerInterface,
        sorted by class name.

    """
    # Normalize to tuples so that callers may pass lists (or any other sequence)
    # without the cache raising `TypeError: unhashable type`.
    return _get_checkers_cached(tuple(include), tuple(exclude), issue_severity)


@lru_cache(maxsize=1)
def _get_checkers_cached(
    include: tuple[str, ...],
    exclude: tuple[str, ...],
    issue_severity: int,
) -> Sequence[CheckerInterface]:
    """Cached implementation of `get_checkers`; all arguments must be hashable."""
    # Check for extensions and try to import them
    _import_extensions()

    checkers: list[CheckerInterface] = []
    for checker in __get_all_subclasses(CheckerInterface):
        if checker.__name__ == DummyChecker.__name__:
            continue

        # Initialize DummyCheckers to retrieve error messages.
        if issubclass(checker, DummyChecker):
            checker()

        enabled, relevant_codes = __get_checker_enabled(
            checker,
            include,
            exclude,
            issue_severity,
        )

        if enabled:
            checkers.append(checker(include=relevant_codes))
        else:
            # Implicit concatenation keeps the space between the two sentences.
            _logger.info(
                "Checker %s is disabled. "
                "Issues from this checker are not counted towards reported number of suppressed issues.",
                checker.__name__,
            )

    checkers.sort(key=lambda x: x.__class__.__name__)

    _logger.info(
        "Discovered and enabled checkers: [%s]",
        ", ".join([c.__class__.__name__ for c in checkers]),
    )
    if not checkers:
        _logger.warning(
            "No checkers were enabled. Check the include and exclude arguments.",
        )

    # Perform a uniqueness check on the codes emitted by the checkers.
    # Each unordered pair is checked once, so every conflict is reported once.
    for index, checker1 in enumerate(checkers):
        for checker2 in checkers[index + 1 :]:
            if not set(checker1.codes).isdisjoint(checker2.codes):
                _logger.error(
                    "Checker %s and %s have overlapping codes.",
                    checker1.__class__.__name__,
                    checker2.__class__.__name__,
                )

    return checkers


# Expose the underlying cache's management functions on `get_checkers`,
# mirroring the `functools.lru_cache` wrapper API (e.g. `get_checkers.cache_clear()`).
get_checkers.cache_clear = _get_checkers_cached.cache_clear  # type: ignore[attr-defined]
get_checkers.cache_info = _get_checkers_cached.cache_info  # type: ignore[attr-defined]
get_checkers.cache_parameters = _get_checkers_cached.cache_parameters  # type: ignore[attr-defined]
104
105
def __get_checker_enabled(
    checker: type[CheckerInterface],
    include: Sequence[str],
    exclude: Sequence[str],
    issue_severity: int,
) -> tuple[bool, set[str]]:
    """Determine whether a checker is enabled and which of its issue codes are relevant.

    The checker starts out as ``checker.enabled_by_default``. It becomes enabled
    when any ``include`` regex matches one of its codes, and it is always disabled
    when no relevant codes remain after all filtering.

    Args:
        checker: The checker class whose ``codes`` mapping is filtered.
        include: Regexes that must match an entire issue code for it to be kept.
            When empty, all codes of the checker are kept.
        exclude: Regexes that remove an issue code when they match it entirely.
        issue_severity: Minimum ``severity`` value a code must have to be kept.

    Returns:
        A tuple ``(enabled, relevant_codes)``.

    """
    enabled = checker.enabled_by_default

    if len(include) == 0:
        # If no include regexes are provided, include all codes by default.
        relevant_codes = set(checker.codes.keys())
    else:
        # Wrap each regex in a non-capturing group so the anchors apply to the whole
        # pattern; plain "^" + regex + "$" would let an alternation such as "A|B"
        # match any code that merely starts with "A" or ends with "B".
        include_patterns = [
            _regex_provider.compile("^(?:" + regex + ")$") for regex in include
        ]
        relevant_codes = {
            code
            for code in checker.codes.keys()
            if any(pattern.match(code) is not None for pattern in include_patterns)
        }

        if relevant_codes:
            enabled = True

    # Remove codes matching any exclude regex and codes below the severity threshold.
    exclude_patterns = [
        _regex_provider.compile("^(?:" + regex + ")$") for regex in exclude
    ]
    relevant_codes = {
        code
        for code in relevant_codes
        if all(pattern.match(code) is None for pattern in exclude_patterns)
        and checker.codes[code]["severity"] >= issue_severity
    }

    if not relevant_codes:
        enabled = False

    return enabled, relevant_codes
159
160
Cls = TypeVar("Cls")


def __get_all_subclasses(cls: type[Cls]) -> Iterable[type[Cls]]:
    """Returns all class types that (directly or indirectly) subclass the provided type.

    Subclasses are returned in depth-first pre-order: each class comes before its
    own subclasses, and siblings follow ``__subclasses__()`` order. A class that is
    reachable through several parents (multiple inheritance) is returned only once,
    so callers that instantiate each result do not create duplicates.
    """
    all_subclasses: list[type[Cls]] = []
    seen: set[type[Cls]] = set()

    # Explicit stack instead of recursion; children are pushed in reverse so that
    # they are popped (and thus emitted) in their original order.
    stack = list(reversed(cls.__subclasses__()))
    while stack:
        subclass = stack.pop()
        if subclass in seen:
            continue
        seen.add(subclass)
        all_subclasses.append(subclass)
        stack.extend(reversed(subclass.__subclasses__()))

    return all_subclasses