1"""`PrincipleMLChecker`."""
2
3import copy
4import logging
5import os
6import pickle
7from collections import Counter
8from collections.abc import Iterable
9from typing import Any, Literal, Optional, Union, overload
10
11import idstools.rule
12import xgboost
13from pandas import DataFrame, Series
14from sklearn.metrics import f1_score, make_scorer, precision_score, recall_score
15from sklearn.model_selection import (
16 GridSearchCV,
17 RepeatedStratifiedKFold,
18 cross_val_score,
19)
20from sklearn.pipeline import Pipeline
21
22from suricata_check._version import SURICATA_CHECK_DIR
23from suricata_check.checkers.interface.checker import CheckerInterface
24from suricata_check.checkers.principle._utils import get_message
25from suricata_check.utils.checker import get_rule_option, get_rule_suboptions
26from suricata_check.utils.checker_typing import ISSUES_TYPE, Issue
27
28_PICKLE_PATH = os.path.join(SURICATA_CHECK_DIR, "data", "principle_ml_checker.pkl")
29N_JOBS = 8
30
31
32_logger = logging.getLogger(__name__)
33
34
COUNT_COLUMNS = (
    "flowbits.isset.count",
    "flowbits.isnotset.count",
    "flowint.isset.count",
    "flowint.isnotset.count",
    "xbits.isset.count",
    "xbits.isnotset.count",
    "http.uri.count",
    "http.method.count",
    "dns.query.count",
    "content.count",
    "pcre.count",
    "startswith.count",
    "bsize.count",
    "depth.count",
    "urilen.count",
    "flow.from_server.count",
    "flow.to_server.count",
    "flow.from_client.count",
    "flow.to_client.count",
)
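# Count features tally how often an option (or suboption) occurs in a rule;
# they are produced in _get_raw_features as "<option>.count" entries.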
STRING_COLUMNS = ()
DROPDOWN_COLUMNS = (
    "proto",
    "threshold.type",
)
NUMERICAL_COLUMNS = ("threshold.count",)
SPLITTABLE_FEATURES = (
    "metadata",
    "flow",
    "threshold",
)
MSG_KEYWORDS = ("Suspicious", "CVE", "Vulnerability", "Response")
# Must be a tuple: a generator expression would be exhausted after its first
# use in zip() and set() below.
MSG_COLUMNS = tuple("msg.contains." + keyword for keyword in MSG_KEYWORDS)
IP_KEYWORDS = ("$HOME_NET", "$HTTP_SERVERS", "$EXTERNAL_NET", "any")
IP_COLUMNS = tuple(
    ["source_addr.contains." + keyword for keyword in IP_KEYWORDS]
    + ["dest_addr.contains." + keyword for keyword in IP_KEYWORDS]
)
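# Keyword indicator features are booleans marking whether the msg text or the
# source/destination address fields contain the given keyword.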


PIPELINE = Pipeline(
    [
        (
            "classify",
            xgboost.XGBClassifier(),
        )
    ]
)
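# The "classify__" prefix in PARAM_GRID below follows sklearn's Pipeline
# convention: "<step name>__<parameter>" routes each value to the XGBoost step.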
# https://shengyg.github.io/repository/machine%20learning/2017/02/25/Complete-Guide-to-Parameter-Tuning-xgboost.html
PARAM_GRID: list[dict] = [
    {
        # Fixed parameters for problem / desired complexity
        "classify__n_estimators": [1000],
        "classify__objective": ["binary:logistic"],
        ###
        # Parameters to optimize
        ## Learning rate
        "classify__eta": [0.01, 0.1, 0.3],
        ## Tree parameters
        "classify__subsample": [1.0],
        "classify__colsample_bytree": [0.25, 0.5, 0.75, 1.0],
        "classify__scale_pos_weight": [0.1, 0.25, 0.5, 1.0, 2.0, 4.0, 10.0],
        "classify__max_depth": [1, 3],
        "classify__min_child_weight": [1],
        "classify__gamma": [0, 0.1],
        ## Regularization
        "classify__lambda": [0, 0.01, 0.1],
        "classify__alpha": [0, 0.01, 0.1],
    },
]

PRECISION_WEIGHT = 10
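# The scorer below is a weighted harmonic mean of precision and recall that
# weights precision PRECISION_WEIGHT times as heavily as recall; this is
# equivalent to an F-beta score with beta^2 = 1 / PRECISION_WEIGHT. The 1e-10
# terms guard against division by zero when either score is 0.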
SCORER = make_scorer(
    lambda y, y_pred: (PRECISION_WEIGHT + 1)
    / (
        PRECISION_WEIGHT / (precision_score(y, y_pred, zero_division=1) + 1e-10)  # type: ignore reportArgumentType
        + 1 / (recall_score(y, y_pred, zero_division=0) + 1e-10)  # type: ignore reportArgumentType
    )
)
SPLITTER = RepeatedStratifiedKFold(n_splits=2, n_repeats=10)
GRIDSEARCHCV = GridSearchCV(
    PIPELINE, PARAM_GRID, cv=SPLITTER, scoring=SCORER, error_score="raise", n_jobs=N_JOBS, verbose=1  # type: ignore reportArgumentType
)
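# With 3 * 1 * 4 * 7 * 2 * 1 * 2 * 3 * 3 = 3,024 parameter combinations and
# 20 cross-validation folds (2 splits x 10 repeats), a full grid search
# performs roughly 60,480 fits per issue code.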


class PrincipleMLChecker(CheckerInterface):
    """The `PrincipleMLChecker` contains several checks based on the Ruling the Unruly paper, targeting rule specificity and coverage.

    Codes Q000-Q005 report on non-adherence to the same rule design principles as those covered by the `PrincipleChecker`.
    Differently, they are the result of machine learning analysis of the rules.
    """
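    # Illustrative usage sketch (assumes the trained model pickle shipped at
    # _PICKLE_PATH, which __new__ loads by default):
    #   checker = PrincipleMLChecker()
    #   parsed = idstools.rule.parse('alert tcp any any -> any any (msg:"..."; sid:1;)')
    #   issues = checker._check_rule(parsed)  # normally invoked via the CheckerInterface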

    count_columns = COUNT_COLUMNS
    string_columns = STRING_COLUMNS
    dropdown_columns = DROPDOWN_COLUMNS
    numerical_columns = NUMERICAL_COLUMNS
    splittable_features = SPLITTABLE_FEATURES
    msg_keywords = MSG_KEYWORDS
    msg_columns = MSG_COLUMNS
    ip_keywords = IP_KEYWORDS
    ip_columns = IP_COLUMNS

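    # Q000-Q005 mirror the labelled principle columns used in train():
    # no_proxy, success, thresholded, exceptions, generalized_match_content,
    # and generalized_match_location, respectively.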
    codes = {
        "Q000": {"severity": logging.INFO},
        "Q001": {"severity": logging.INFO},
        "Q002": {"severity": logging.INFO},
        "Q003": {"severity": logging.INFO},
        "Q004": {"severity": logging.INFO},
        "Q005": {"severity": logging.INFO},
    }

    enabled_by_default = (
        False  # Since the checker is relatively slow, it is disabled by default
    )

    _dtypes: Optional[dict[str, Any]] = None
    _models: dict[str, Pipeline] = {}

    def __new__(
        cls: type["PrincipleMLChecker"],
        filepath: Optional[str] = _PICKLE_PATH,
        *args: tuple,
        **kwargs: dict,
    ) -> "PrincipleMLChecker":
        """Returns a new or unpickled instance of the class."""
        if filepath:
            if os.path.exists(filepath):
                with open(filepath, "rb") as f:
                    inst = pickle.load(f)

                # BEGIN LEGACY CODE
                # CAN BE REMOVED AFTER TRAINING NEW PKL
                if hasattr(inst, "models"):
                    inst._models = inst.models  # noqa: SLF001
                    inst._dtypes = inst.dtypes  # noqa: SLF001
                # END LEGACY CODE

                if inst.__class__.__name__ != cls.__name__:
                    _logger.error("Unpickled object is not of type %s", cls)
                    inst = super().__new__(cls, *args, **kwargs)  # type: ignore reportArgumentType
                elif not hasattr(inst, "_models") or len(inst._models) == 0:
                    _logger.error("Unpickled object does not have trained models")
                    inst = super().__new__(cls, *args, **kwargs)  # type: ignore reportArgumentType
                else:
                    if "include" in kwargs:
                        inst.include = kwargs["include"]
                    _logger.info("Unpickled object with trained models successfully")
            else:
                _logger.warning("No model found for PrincipleMLChecker at %s", filepath)
                inst = super().__new__(cls, *args, **kwargs)  # type: ignore reportArgumentType
        else:
            inst = super().__new__(cls, *args, **kwargs)  # type: ignore reportArgumentType

        return inst

    def __getnewargs__(
        self: "PrincipleMLChecker",
    ) -> tuple:
        """Returns the arguments to be passed to the __new__ method when unpickling.

        Returning `(None,)` passes `filepath=None` so that `__new__` does not attempt to reload the pickle while unpickling.
        """
        return (None,)

    def _check_rule(
        self: "PrincipleMLChecker",
        rule: idstools.rule.Rule,
    ) -> ISSUES_TYPE:
        issues: ISSUES_TYPE = []

        if len(self._models) == 0:
            return issues

        for code, model in self._models.items():
            if model.predict(self._get_features(rule, True))[0]:
                issues.append(
                    Issue(
                        code=code,
                        message=get_message(code),
                    )
                )

        return issues

    def train(  # noqa: C901
        self: "PrincipleMLChecker",
        df: DataFrame,
        rule_col: str = "rule.rule",
        principle_cols: dict[str, str] = {
            "Q000": "labelled.no_proxy",
            "Q001": "labelled.success",
            "Q002": "labelled.thresholded",
            "Q003": "labelled.exceptions",
            "Q004": "labelled.generalized_match_content",
            "Q005": "labelled.generalized_match_location",
        },
        reuse_models: bool = False,
    ) -> None:
        """Train several models for the checker to detect issues in rules.

        The checker class with trained models is stored in a pickle file (`_PICKLE_PATH`).
        """
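        # A minimal training sketch (assuming `df` holds the rule text in
        # `rule.rule` and the labelled principle columns listed above):
        #   checker = PrincipleMLChecker(filepath=None)
        #   checker.train(df)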
        self._dtypes = None
        if not reuse_models:
            self._models = {}

        # Extract features and determine feature dtypes
        X_train = self._get_train_df(df[rule_col])  # noqa: N806

        for col in X_train.columns:
            try:
                X_train[col].var()
                _logger.debug("Detected column: %s", col)
            except Exception:
                _logger.error("Error with column %s", col)
                _logger.error(X_train[col])
                raise

        # Drop zero-variance columns
        X_train = X_train.drop(  # noqa: N806
            X_train.columns[(X_train.fillna(-1337).var(axis=0) <= 0)].to_list(),  # type: ignore reportAttributeAccessIssue
            axis=1,
        )

        # Drop columns with too few occurrences of possible values
        for col in X_train.columns:
            if (
                not col.endswith(".count")
                and not col.endswith(".num")
                and not col.endswith(".len")
            ):
                if X_train[col].value_counts().min() <= 1:
                    X_train = X_train.drop(  # noqa: N806
                        [col],
                        axis=1,
                    )

        for col in X_train.columns:
            try:
                X_train[col].var()
                _logger.info("Using column: %s", col)
            except Exception:
                _logger.error("Error with column %s", col)
                _logger.error(X_train[col])
                raise

        # Store used features and their dtypes
        self._dtypes = X_train.dtypes.to_dict()
        _logger.debug(self._dtypes)

        # Redo feature extraction now that FE parameters are set
        X_train = self._get_train_df(df[rule_col])  # noqa: N806

        _logger.info(
            "Training model with features: [%s]",
            ", ".join([str(x) for x in X_train.columns]),
        )

        _logger.info(X_train)

        for code, col in principle_cols.items():
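            # The positive class is "issue present": a 0 in the labelled
            # column marks non-adherence to the principle.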
            y_true = df[col].to_numpy() == 0

            if not reuse_models or code not in self._models:
                # Train new model with grid search to find optimal parameters
                gridsearchcv: GridSearchCV = copy.deepcopy(GRIDSEARCHCV)

                gridsearchcv.fit(X_train, y_true)

                _logger.info("Code %s params: %s", code, gridsearchcv.best_params_)
                _logger.info(
                    "Code %s Weighted F1-score: %s", code, gridsearchcv.best_score_
                )

                self._models[code] = gridsearchcv.best_estimator_

            precision = cross_val_score(
                self._models[code],
                X_train,
                y_true,
                scoring=make_scorer(precision_score, zero_division=0.0),
                cv=SPLITTER,
                n_jobs=N_JOBS,
            ).mean()
            recall = cross_val_score(
                self._models[code],
                X_train,
                y_true,
                scoring=make_scorer(recall_score, zero_division=0.0),
                cv=SPLITTER,
                n_jobs=N_JOBS,
            ).mean()
            f1 = cross_val_score(
                self._models[code],
                X_train,
                y_true,
                scoring=make_scorer(f1_score, zero_division=0.0),
                cv=SPLITTER,
                n_jobs=N_JOBS,
            ).mean()
            _logger.info("Code %s Precision score: %s", code, precision)
            _logger.info("Code %s Recall score: %s", code, recall)
            _logger.info("Code %s F1-score: %s", code, f1)

            # Refit model with training data.
            self._models[code].fit(X_train, y_true)

        with open(_PICKLE_PATH, "wb") as f:
            pickle.dump(self, f)

    def _get_train_df(self: "PrincipleMLChecker", rules: Iterable[str]) -> DataFrame:
        feature_vectors = []
        for rule in rules:
            parsed_rule = idstools.rule.parse(rule)
            assert parsed_rule is not None
            feature_vectors.append(self._get_features(parsed_rule, False))

        return DataFrame(feature_vectors)

    def _get_raw_features(  # noqa: C901
        self: "PrincipleMLChecker", rule: idstools.rule.Rule
    ) -> Series:
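        # Illustrative example: a rule with `content:"a"; content:"b";` yields
        # content.count=2, and a splittable option such as
        # `flow:established,to_server;` additionally yields suboption features
        # like flow.to_server.count=1.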
        d: dict[str, Optional[Union[str, int]]] = {
            "proto": get_rule_option(rule, "proto")
        }

        options = rule["options"]

        for option in options:
            d[option["name"]] = option["value"]

        counter = Counter([option["name"] for option in options])
        for name, count in counter.items():
            d[name + ".count"] = count

        for option in options:
            if option["name"] not in self.splittable_features:
                continue

            suboptions = [
                {"name": k, "value": v}
                for k, v in get_rule_suboptions(rule, option["name"], warn=False)
            ]

            if len(suboptions) == 0:
                continue

            for suboption in suboptions:
                d[option["name"] + "." + suboption["name"]] = suboption["value"]

            counter = Counter([suboption["name"] for suboption in suboptions])
            for name, count in counter.items():
                d[option["name"] + "." + name + ".count"] = count

        msg = get_rule_option(rule, "msg")
        assert msg is not None
        msg = msg.lower()
        for col, keyword in zip(self.msg_columns, self.msg_keywords):
            d[col] = keyword.lower() in msg

        source_addr = get_rule_option(rule, "source_addr")
        assert source_addr is not None
        source_addr = source_addr.lower()
        for keyword in self.ip_keywords:
            col = "source_addr.contains." + keyword
            d[col] = keyword.lower() in source_addr

        dest_addr = get_rule_option(rule, "dest_addr")
        assert dest_addr is not None
        dest_addr = dest_addr.lower()
        for keyword in self.ip_keywords:
            col = "dest_addr.contains." + keyword
            d[col] = keyword.lower() in dest_addr

        return Series(d)

    def _preprocess_features(self: "PrincipleMLChecker", data: Series) -> Series:
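        # String features become ".len" lengths, dropdown (categorical)
        # features become one-hot ".bool" columns (e.g. proto="tcp" ->
        # "proto.tcp.bool"=1), numerical features are cast to ".num" floats,
        # and any column not whitelisted above is dropped.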
        original_cols: set[str] = set(data.index)

        for col in self.string_columns:
            if col not in data:
                continue
            data[col + ".len"] = len(data[col])
            data = data.drop(col)

        for col in self.dropdown_columns:
            if col not in data:
                continue
            data[col + "." + data[col] + ".bool"] = 1
            data = data.drop(col)

        for col in self.numerical_columns:
            if col not in data:
                continue
            data[col + ".num"] = float(data[col])
            data = data.drop(col)

        remaining_cols = (
            original_cols
            - set(self.count_columns)
            - set(self.string_columns)
            - set(self.dropdown_columns)
            - set(self.numerical_columns)
            - set(self.msg_columns)
            - set(self.ip_columns)
        )

        for col in remaining_cols:
            data = data.drop(col)

        return data

    def _get_features_frame(self: "PrincipleMLChecker", features: Series) -> DataFrame:
        features_frame = features.to_frame().transpose()

        if self._dtypes is None:
            return features_frame

        for col, dtype in self._dtypes.items():
            if features_frame.dtypes[col] != dtype:
                features_frame[col] = features_frame[col].astype(dtype)

        return features_frame

    @overload
    def _get_features(
        self: "PrincipleMLChecker", rule: idstools.rule.Rule, frame: Literal[True]
    ) -> DataFrame:
        pass

    @overload
    def _get_features(
        self: "PrincipleMLChecker", rule: idstools.rule.Rule, frame: Literal[False]
    ) -> Series:
        pass

    def _get_features(
        self: "PrincipleMLChecker", rule: idstools.rule.Rule, frame: bool
    ) -> Union[Series, DataFrame]:
        features: Series = self._get_raw_features(rule)
        features = self._preprocess_features(features)

        features["custom.negated.count"] = rule["raw"].count(':!"')
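        # rule["raw"] holds the rule text; the ':!"' substring marks negated
        # matches such as content:!"..." in Suricata syntax.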

        if self._dtypes is None:
            return features

        for col, dtype in self._dtypes.items():
            if col not in features:
                if col.endswith(".count"):
                    features[col] = 0
                elif col.endswith(".bool"):
                    features[col] = 0
                elif col.endswith(".num"):
                    features[col] = -1
                else:
                    _logger.error(
                        "Unsure how to handle missing feature %s of type %s",
                        col,
                        dtype,
                    )

        features = features[list(self._dtypes.keys())]  # type: ignore reportAssignmentType

        if not frame:
            return features

        return self._get_features_frame(features)