1"""`PrincipleMLChecker`."""
2
3import copy
4import logging
5import os
6import pickle
7from collections import Counter
8from collections.abc import Iterable
9from typing import Any, Literal, Optional, Union, overload
10
11import idstools.rule
12import xgboost
13from pandas import DataFrame, Series
14from sklearn.metrics import f1_score, make_scorer, precision_score, recall_score
15from sklearn.model_selection import (
16 GridSearchCV,
17 RepeatedStratifiedKFold,
18 cross_val_score,
19)
20from sklearn.pipeline import Pipeline
21
22from suricata_check._version import SURICATA_CHECK_DIR
23from suricata_check.checkers.interface.checker import CheckerInterface
24from suricata_check.checkers.principle._utils import get_message
25from suricata_check.utils.checker import get_rule_option, get_rule_suboptions
26from suricata_check.utils.checker_typing import ISSUES_TYPE, Issue
27
28_PICKLE_PATH = os.path.join(SURICATA_CHECK_DIR, "data", "principle_ml_checker.pkl")
29N_JOBS = 8
30
31
32_logger = logging.getLogger(__name__)
33
34
35COUNT_COLUMNS = (
36 "flowbits.isset.count",
    "flowbits.isnotset.count",
    "flowint.isset.count",
    "flowint.isnotset.count",
    "xbits.isset.count",
    "xbits.isnotset.count",
    "http.uri.count",
    "http.method.count",
    "dns.query.count",
    "content.count",
    "pcre.count",
    "startswith.count",
    "bsize.count",
    "depth.count",
    "urilen.count",
    "flow.from_server.count",
    "flow.to_server.count",
    "flow.from_client.count",
    "flow.to_client.count",
)
STRING_COLUMNS = ()
DROPDOWN_COLUMNS = (
    "proto",
    "threshold.type",
)
NUMERICAL_COLUMNS = ("threshold.count",)
SPLITTABLE_FEATURES = (
    "metadata",
    "flow",
    "threshold",
)
MSG_KEYWORDS = ("Suspicious", "CVE", "Vulnerability", "Response")
MSG_COLUMNS = tuple("msg.contains." + keyword for keyword in MSG_KEYWORDS)
IP_KEYWORDS = ("$HOME_NET", "$HTTP_SERVERS", "$EXTERNAL_NET", "any")
IP_COLUMNS = tuple(
    ["source_addr.contains." + keyword for keyword in IP_KEYWORDS]
    + ["dest_addr.contains." + keyword for keyword in IP_KEYWORDS]
)
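# Naming scheme for the features constructed by this checker:
# "<option>.count" counts keyword occurrences in a rule,
# "msg.contains.<keyword>" and "<source|dest>_addr.contains.<keyword>" are
# boolean indicators, "<option>.len" holds string lengths,
# "<option>.<value>.bool" one-hot encodes dropdown-style options, and
# "<option>.num" holds numeric casts (see `_get_raw_features` and
# `_preprocess_features` below).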


PIPELINE = Pipeline(
    [
        (
            "classify",
            xgboost.XGBClassifier(),
        )
    ]
)
# https://shengyg.github.io/repository/machine%20learning/2017/02/25/Complete-Guide-to-Parameter-Tuning-xgboost.html
PARAM_GRID: list[dict] = [
    {
        # Fixed parameters for problem / desired complexity
        "classify__n_estimators": [1000],
        "classify__objective": ["binary:logistic"],
        ###
        # Parameters to optimize
        ## Learning rate
        "classify__eta": [0.01, 0.1, 0.3],
        ## Tree parameters
        "classify__subsample": [1.0],
        "classify__colsample_bytree": [0.25, 0.5, 0.75, 1.0],
        "classify__scale_pos_weight": [0.1, 0.25, 0.5, 1.0, 2.0, 4.0, 10.0],
        "classify__max_depth": [1, 3],
        "classify__min_child_weight": [1],
        "classify__gamma": [0, 0.1],
        ## Regularization
        "classify__lambda": [0, 0.01, 0.1],
        "classify__alpha": [0, 0.01, 0.1],
    },
]

PRECISION_WEIGHT = 10
SCORER = make_scorer(
    lambda y, y_pred: (PRECISION_WEIGHT + 1)
    / (
        PRECISION_WEIGHT / (precision_score(y, y_pred, zero_division=1) + 1e-10)  # type: ignore reportArgumentType
        + 1 / (recall_score(y, y_pred, zero_division=0) + 1e-10)  # type: ignore reportArgumentType
    )
)
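# Note: up to the 1e-10 smoothing terms and the zero-division handling, the
# scorer above is a precision-weighted harmonic mean of precision and recall,
# i.e. an F-beta score with beta**2 == 1 / PRECISION_WEIGHT (beta ~= 0.32),
# so model selection favors precision over recall. A rough equivalent sketch
# (not what is used here) would be:
#     from sklearn.metrics import fbeta_score
#     SCORER = make_scorer(fbeta_score, beta=(1 / PRECISION_WEIGHT) ** 0.5)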
SPLITTER = RepeatedStratifiedKFold(n_splits=2, n_repeats=10)
GRIDSEARCHCV = GridSearchCV(
    PIPELINE, PARAM_GRID, cv=SPLITTER, scoring=SCORER, error_score="raise", n_jobs=N_JOBS, verbose=1  # type: ignore reportArgumentType
)


class PrincipleMLChecker(CheckerInterface):
    """The `PrincipleMLChecker` contains several checks based on the Ruling the Unruly paper, targeting rule specificity and coverage.

    Codes Q000-Q005 report on non-adherence to rule design principles, similar to the codes emitted by the `PrincipleChecker`.
    Differently from that checker, they are the result of machine learning analysis of the rules.
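
    Example (illustrative sketch, assuming a trained model pickle exists at `_PICKLE_PATH`
    and that a public `check_rule` entrypoint inherited from `CheckerInterface` is used):

        checker = PrincipleMLChecker()
        rule = idstools.rule.parse('alert tcp any any -> any any (msg:"Example"; sid:1;)')
        issues = checker.check_rule(rule)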
    """

    count_columns = COUNT_COLUMNS
    string_columns = STRING_COLUMNS
    dropdown_columns = DROPDOWN_COLUMNS
    numerical_columns = NUMERICAL_COLUMNS
    splittable_features = SPLITTABLE_FEATURES
    msg_keywords = MSG_KEYWORDS
    msg_columns = MSG_COLUMNS
    ip_keywords = IP_KEYWORDS
    ip_columns = IP_COLUMNS

    codes = {
        "Q000": {"severity": logging.INFO},
        "Q001": {"severity": logging.INFO},
        "Q002": {"severity": logging.INFO},
        "Q003": {"severity": logging.INFO},
        "Q004": {"severity": logging.INFO},
        "Q005": {"severity": logging.INFO},
    }
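
    # Q000-Q005 correspond, in order, to the labelled.no_proxy, labelled.success,
    # labelled.thresholded, labelled.exceptions, labelled.generalized_match_content,
    # and labelled.generalized_match_location columns used by `train()` below;
    # one binary classifier is trained per code.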

    enabled_by_default = (
        False  # Since the checker is relatively slow, it is disabled by default
    )

    _dtypes: Optional[dict[str, Any]] = None
    _models: dict[str, Pipeline] = {}

    def __new__(
        cls: type["PrincipleMLChecker"],
        filepath: Optional[str] = _PICKLE_PATH,
        *args: tuple,
        **kwargs: dict,
    ) -> "PrincipleMLChecker":
        """Returns a new or unpickled instance of the class."""
        if filepath:
            if os.path.exists(filepath):
                with open(filepath, "rb") as f:
                    inst = pickle.load(f)

                # BEGIN LEGACY CODE
                # CAN BE REMOVED AFTER TRAINING NEW PKL
                if hasattr(inst, "models"):
                    inst._models = inst.models  # noqa: SLF001
                    inst._dtypes = inst.dtypes  # noqa: SLF001
                # END LEGACY CODE

                if inst.__class__.__name__ != cls.__name__:
                    _logger.error("Unpickled object is not of type %s", cls)
                    inst = super().__new__(cls, *args, **kwargs)  # type: ignore reportArgumentType
                elif (
                    not hasattr(inst, "_models")
                    or len(
                        inst._models  # noqa: SLF001 type: ignore reportAttributeAccess
                    )
                    == 0
                ):
                    _logger.error("Unpickled object does not have trained models")
                    inst = super().__new__(cls, *args, **kwargs)  # type: ignore reportArgumentType
                else:
                    if "include" in kwargs:
                        inst.include = kwargs["include"]
                    _logger.info("Unpickled object with trained models successfully")
            else:
                _logger.warning("No model found for PrincipleMLChecker at %s", filepath)
                inst = super().__new__(cls, *args, **kwargs)  # type: ignore reportArgumentType
        else:
            inst = super().__new__(cls, *args, **kwargs)  # type: ignore reportArgumentType

        return inst

    def __getnewargs__(
        self: "PrincipleMLChecker",
    ) -> tuple:
        """Returns the arguments to be passed to the __new__ method when unpickling."""
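        # Passing `(None,)` as the `filepath` argument makes `__new__` skip the
        # pickle-loading branch during unpickling, so restoring an instance does
        # not recursively reload `_PICKLE_PATH`.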
        return (None,)

    def _check_rule(
        self: "PrincipleMLChecker",
        rule: idstools.rule.Rule,
    ) -> ISSUES_TYPE:
        issues: ISSUES_TYPE = []

        if len(self._models) == 0:
            return issues

        for code, model in self._models.items():
            if model.predict(self._get_features(rule, True))[0]:
                issues.append(
                    Issue(
                        code=code,
                        message=get_message(code),
                    )
                )

        return issues

    def train(  # noqa: C901
        self: "PrincipleMLChecker",
        df: DataFrame,
        rule_col: str = "rule.rule",
        principle_cols: dict[str, str] = {
            "Q000": "labelled.no_proxy",
            "Q001": "labelled.success",
            "Q002": "labelled.thresholded",
            "Q003": "labelled.exceptions",
            "Q004": "labelled.generalized_match_content",
            "Q005": "labelled.generalized_match_location",
        },
        reuse_models: bool = False,
    ) -> None:
        """Train several models for the checker to detect issues in rules.

        The checker class with trained models is stored in a pickle file (`_PICKLE_PATH`).
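
        A minimal illustrative call (hypothetical data; in practice the DataFrame
        should contain many labelled rules):

            df = DataFrame(
                {
                    "rule.rule": [...],  # raw Suricata rules as strings
                    "labelled.no_proxy": [...],  # 0/1 labels, one per principle column
                    # ... remaining labelled.* columns from the defaults above
                },
            )
            PrincipleMLChecker().train(df)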
        """
        self._dtypes = None
        if not reuse_models:
            self._models = {}

        # Extract features and determine feature dtypes
        X_train = self._get_train_df(df[rule_col])  # noqa: N806

        for col in X_train.columns:
            try:
                X_train[col].var()
                _logger.debug("Detected column: %s", col)
            except Exception:
                _logger.error("Error with column %s", col)
                _logger.error(X_train[col])
                raise

        # Drop zero variance columns
        X_train = X_train.drop(  # noqa: N806
            X_train.columns[(X_train.fillna(-1337).var(axis=0) <= 0)].to_list(),  # type: ignore reportAttributeAccessIssue
            axis=1,
        )

        # Drop columns with too few occurrences of possible values
        for col in X_train.columns:
            if (
                not col.endswith(".count")
                and not col.endswith(".num")
                and not col.endswith(".len")
            ):
                if X_train[col].value_counts().min() <= 1:
                    X_train = X_train.drop(  # noqa: N806
                        [col],
                        axis=1,
                    )

        for col in X_train.columns:
            try:
                X_train[col].var()
                _logger.info("Using column: %s", col)
            except Exception:
                _logger.error("Error with column %s", col)
                _logger.error(X_train[col])
                raise

        # Store used features and their dtypes
        self._dtypes = X_train.dtypes.to_dict()
        _logger.debug(self._dtypes)

        # Redo feature extraction now that FE parameters are set
        X_train = self._get_train_df(df[rule_col])  # noqa: N806

        _logger.info(
            "Training model with features: [%s]",
            ", ".join([str(x) for x in X_train.columns]),
        )

        _logger.info(X_train)

        for code, col in principle_cols.items():
            y_true = df[col].to_numpy() == 0

            if not reuse_models or code not in self._models:
                # Train new model with grid search to find optimal parameters
                gridsearchcv: GridSearchCV = copy.deepcopy(GRIDSEARCHCV)

                gridsearchcv.fit(X_train, y_true)

                _logger.info("Code %s params: %s", code, gridsearchcv.best_params_)
                _logger.info(
                    "Code %s Weighted F1-score: %s", code, gridsearchcv.best_score_
                )

                self._models[code] = gridsearchcv.best_estimator_

            precision = cross_val_score(
                self._models[code],
                X_train,
                y_true,
                scoring=make_scorer(precision_score, zero_division=0.0),
                cv=SPLITTER,
                n_jobs=N_JOBS,
            ).mean()
            recall = cross_val_score(
                self._models[code],
                X_train,
                y_true,
                scoring=make_scorer(recall_score, zero_division=0.0),
                cv=SPLITTER,
                n_jobs=N_JOBS,
            ).mean()
            f1 = cross_val_score(
                self._models[code],
                X_train,
                y_true,
                scoring=make_scorer(f1_score, zero_division=0.0),
                cv=SPLITTER,
                n_jobs=N_JOBS,
            ).mean()
            _logger.info("Code %s Precision score: %s", code, precision)
            _logger.info("Code %s Recall score: %s", code, recall)
            _logger.info("Code %s F1-score: %s", code, f1)

            # Refit model with training data.
            self._models[code].fit(X_train, y_true)

        with open(_PICKLE_PATH, "wb") as f:
            pickle.dump(self, f)

    def _get_train_df(self: "PrincipleMLChecker", rules: Iterable[str]) -> DataFrame:
        feature_vectors = []
        for rule in rules:
            parsed_rule = idstools.rule.parse(rule)
            feature_vectors.append(self._get_features(parsed_rule, False))

        return DataFrame(feature_vectors)

    def _get_raw_features(  # noqa: C901
        self: "PrincipleMLChecker", rule: idstools.rule.Rule
    ) -> Series:
        d: dict[str, Optional[Union[str, int]]] = {
            "proto": get_rule_option(rule, "proto")
        }

        options = rule["options"]

        for option in options:
            d[option["name"]] = option["value"]

        counter = Counter([option["name"] for option in options])
        for option, count in counter.items():
            d[option + ".count"] = count

        for option in options:
            if option["name"] not in self.splittable_features:
                continue

            suboptions = [
                {"name": k, "value": v}
                for k, v in get_rule_suboptions(rule, option["name"], warn=False)
            ]

            if len(suboptions) == 0:
                continue

            for suboption in suboptions:
                d[option["name"] + "." + suboption["name"]] = suboption["value"]

            counter = Counter([suboption["name"] for suboption in suboptions])
            for suboption, count in counter.items():
                d[option["name"] + "." + suboption + ".count"] = count

        msg = get_rule_option(rule, "msg")
        assert msg is not None
        msg = msg.lower()
        for col, keyword in zip(self.msg_columns, self.msg_keywords):
            d[col] = keyword.lower() in msg

        source_addr = get_rule_option(rule, "source_addr")
        assert source_addr is not None
        source_addr = source_addr.lower()
        for keyword in self.ip_keywords:
            col = "source_addr.contains." + keyword
            d[col] = keyword.lower() in source_addr

        dest_addr = get_rule_option(rule, "dest_addr")
        assert dest_addr is not None
        dest_addr = dest_addr.lower()
        for keyword in self.ip_keywords:
            col = "dest_addr.contains." + keyword
            d[col] = keyword.lower() in dest_addr

        return Series(d)

    def _preprocess_features(self: "PrincipleMLChecker", data: Series) -> Series:
        original_cols: set[str] = set(data.index)

        for col in self.string_columns:
            if col not in data:
                continue
            data[col + ".len"] = len(data[col])
            data = data.drop(col)

        for col in self.dropdown_columns:
            if col not in data:
                continue
            data[col + "." + data[col] + ".bool"] = 1
            data = data.drop(col)
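        # For example, a rule with `proto` equal to "tcp" yields a one-hot style
        # feature "proto.tcp.bool" with value 1, while the raw "proto" column is dropped.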

        for col in self.numerical_columns:
            if col not in data:
                continue
            data[col + ".num"] = float(data[col])  # type: ignore reportArgumentType
            data = data.drop(col)

        remaining_cols = (
            original_cols
            - set(self.count_columns)
            - set(self.string_columns)
            - set(self.dropdown_columns)
            - set(self.numerical_columns)
            - set(self.msg_columns)
            - set(self.ip_columns)
        )

        for col in remaining_cols:
            data = data.drop(col)

        return data

    @overload
    def _get_features(
        self: "PrincipleMLChecker", rule: idstools.rule.Rule, frame: Literal[True]
    ) -> DataFrame:
        pass

    @overload
    def _get_features(
        self: "PrincipleMLChecker", rule: idstools.rule.Rule, frame: Literal[False]
    ) -> Series:
        pass

    def _get_features_frame(self: "PrincipleMLChecker", features: Series) -> DataFrame:
        features_frame = features.to_frame().transpose()

        if self._dtypes is None:
            return features_frame

        for col, dtype in self._dtypes.items():
            if features_frame.dtypes[col] != dtype:
                features_frame[col] = features_frame[col].astype(dtype)

        return features_frame

    def _get_features(
        self: "PrincipleMLChecker", rule: idstools.rule.Rule, frame: bool
    ) -> Union[Series, DataFrame]:
        features: Series = self._get_raw_features(rule)
        features = self._preprocess_features(features)

        # Count negated matches (e.g., `content:!"..."`) in the raw rule text.
        features["custom.negated.count"] = rule["raw"].count(':!"')

        if self._dtypes is None:
            return features

        for col, dtype in self._dtypes.items():
            if col not in features:
                if col.endswith(".count"):
                    features[col] = 0
                elif col.endswith(".bool"):
                    features[col] = 0
                elif col.endswith(".num"):
                    features[col] = -1
                else:
                    _logger.error(
                        "Unsure how to handle missing feature %s of type %s",
                        col,
                        dtype,
                    )

        features = features[list(self._dtypes.keys())]  # type: ignore reportAssignmentType

        if not frame:
            return features

        return self._get_features_frame(features)