Source code for suricata_check.checkers.principle.ml

  1"""`PrincipleMLChecker`."""
  2
  3import copy
  4import logging
  5import os
  6import pickle
  7from collections import Counter
  8from collections.abc import Iterable
  9from typing import Any, Literal, Optional, Union, overload
 10
 11import idstools.rule
 12import xgboost
 13from pandas import DataFrame, Series
 14from sklearn.metrics import f1_score, make_scorer, precision_score, recall_score
 15from sklearn.model_selection import (
 16    GridSearchCV,
 17    RepeatedStratifiedKFold,
 18    cross_val_score,
 19)
 20from sklearn.pipeline import Pipeline
 21
 22from suricata_check._version import SURICATA_CHECK_DIR
 23from suricata_check.checkers.interface.checker import CheckerInterface
 24from suricata_check.checkers.principle._utils import get_message
 25from suricata_check.utils.checker import get_rule_option, get_rule_suboptions
 26from suricata_check.utils.checker_typing import ISSUES_TYPE, Issue
 27
_PICKLE_PATH = os.path.join(SURICATA_CHECK_DIR, "data", "principle_ml_checker.pkl")
N_JOBS = 8


_logger = logging.getLogger(__name__)

COUNT_COLUMNS = (
    "flowbits.isset.count",
    "flowbits.isnotset.count",
    "flowint.isset.count",
    "flowint.isnotset.count",
    "xbits.isset.count",
    "xbits.isnotset.count",
    "http.uri.count",
    "http.method.count",
    "dns.query.count",
    "content.count",
    "pcre.count",
    "startswith.count",
    "bsize.count",
    "depth.count",
    "urilen.count",
    "flow.from_server.count",
    "flow.to_server.count",
    "flow.from_client.count",
    "flow.to_client.count",
)
STRING_COLUMNS = ()
DROPDOWN_COLUMNS = (
    "proto",
    "threshold.type",
)
NUMERICAL_COLUMNS = ("threshold.count",)
SPLITTABLE_FEATURES = (
    "metadata",
    "flow",
    "threshold",
)
MSG_KEYWORDS = ("Suspicious", "CVE", "Vulnerability", "Response")
MSG_COLUMNS = tuple("msg.contains." + keyword for keyword in MSG_KEYWORDS)
IP_KEYWORDS = ("$HOME_NET", "$HTTP_SERVERS", "$EXTERNAL_NET", "any")
IP_COLUMNS = tuple(
    ["source_addr.contains." + keyword for keyword in IP_KEYWORDS]
    + ["dest_addr.contains." + keyword for keyword in IP_KEYWORDS]
)
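# The msg.contains.* and *_addr.contains.* columns are boolean indicator
# features: each records whether a keyword occurs in a rule's msg, source
# address, or destination address (computed in `_get_raw_features` below).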


PIPELINE = Pipeline(
    [
        (
            "classify",
            xgboost.XGBClassifier(),
        )
    ]
)
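# The pipeline has a single XGBoost classification step; its "classify" name
# corresponds to the "classify__" parameter prefix used in PARAM_GRID below.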
# https://shengyg.github.io/repository/machine%20learning/2017/02/25/Complete-Guide-to-Parameter-Tuning-xgboost.html
PARAM_GRID: list[dict] = [
    {
        # Fixed parameters for problem / desired complexity
        "classify__n_estimators": [1000],
        "classify__objective": ["binary:logistic"],
        ###
        # Parameters to optimize
        ## Learning rate
        "classify__eta": [0.01, 0.1, 0.3],
        ## Tree parameters
        "classify__subsample": [1.0],
        "classify__colsample_bytree": [0.25, 0.5, 0.75, 1.0],
        "classify__scale_pos_weight": [0.1, 0.25, 0.5, 1.0, 2.0, 4.0, 10.0],
        "classify__max_depth": [1, 3],
        "classify__min_child_weight": [1],
        "classify__gamma": [0, 0.1],
        ## Regularization
        "classify__lambda": [0, 0.01, 0.1],
        "classify__alpha": [0, 0.01, 0.1],
    },
]

PRECISION_WEIGHT = 10
SCORER = make_scorer(
    lambda y, y_pred: (PRECISION_WEIGHT + 1)
    / (
        PRECISION_WEIGHT / (precision_score(y, y_pred, zero_division=1) + 1e-10)  # type: ignore reportArgumentType
        + 1 / (recall_score(y, y_pred, zero_division=0) + 1e-10)  # type: ignore reportArgumentType
    )
)
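# SCORER is a precision-weighted harmonic mean of precision and recall,
# algebraically equivalent to the F-beta measure with beta^2 = 1 / PRECISION_WEIGHT
# (beta ~ 0.32 here), so precision errors weigh roughly ten times as heavily
# as recall errors. The 1e-10 terms merely guard against division by zero.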
SPLITTER = RepeatedStratifiedKFold(n_splits=2, n_repeats=10)
GRIDSEARCHCV = GridSearchCV(
    PIPELINE, PARAM_GRID, cv=SPLITTER, scoring=SCORER, error_score="raise", n_jobs=N_JOBS, verbose=1  # type: ignore reportArgumentType
)

class PrincipleMLChecker(CheckerInterface):
    """The `PrincipleMLChecker` contains several checks based on the Ruling the Unruly paper that target specificity and coverage.

    Codes Q000-Q005 report on non-adherence to rule design principles, similar to codes P000-P005 of the `PrincipleChecker`.
    Unlike those, they are the result of machine-learning analysis of the rules.
    """

    count_columns = COUNT_COLUMNS
    string_columns = STRING_COLUMNS
    dropdown_columns = DROPDOWN_COLUMNS
    numerical_columns = NUMERICAL_COLUMNS
    splittable_features = SPLITTABLE_FEATURES
    msg_keywords = MSG_KEYWORDS
    msg_columns = MSG_COLUMNS
    ip_keywords = IP_KEYWORDS
    ip_columns = IP_COLUMNS

    codes = {
        "Q000": {"severity": logging.INFO},
        "Q001": {"severity": logging.INFO},
        "Q002": {"severity": logging.INFO},
        "Q003": {"severity": logging.INFO},
        "Q004": {"severity": logging.INFO},
        "Q005": {"severity": logging.INFO},
    }

    enabled_by_default = (
        False  # Since the checker is relatively slow, it is disabled by default
    )

    _dtypes: Optional[dict[str, Any]] = None
    _models: dict[str, Pipeline] = {}

    def __new__(
        cls: type["PrincipleMLChecker"],
        filepath: Optional[str] = _PICKLE_PATH,
        *args: tuple,
        **kwargs: dict,
    ) -> "PrincipleMLChecker":
        """Returns a new or unpickled instance of the class."""
        if filepath:
            if os.path.exists(filepath):
                with open(filepath, "rb") as f:
                    inst = pickle.load(f)

                # BEGIN LEGACY CODE
                # CAN BE REMOVED AFTER TRAINING NEW PKL
                if hasattr(inst, "models"):
                    inst._models = inst.models  # noqa: SLF001
                    inst._dtypes = inst.dtypes  # noqa: SLF001
                # END LEGACY CODE

                if inst.__class__.__name__ != cls.__name__:
                    _logger.error("Unpickled object is not of type %s", cls)
                    inst = super().__new__(cls, *args, **kwargs)  # type: ignore reportArgumentType
                elif (
                    not hasattr(inst, "_models")
                    or len(
                        inst._models  # noqa: SLF001 type: ignore reportAttributeAccess
                    )
                    == 0
                ):
                    _logger.error("Unpickled object does not have trained models")
                    inst = super().__new__(cls, *args, **kwargs)  # type: ignore reportArgumentType
                else:
                    if "include" in kwargs:
                        inst.include = kwargs["include"]
                    _logger.info("Unpickled object with trained models successfully")
            else:
                _logger.warning("No model found for PrincipleMLChecker at %s", filepath)
                inst = super().__new__(cls, *args, **kwargs)  # type: ignore reportArgumentType
        else:
            inst = super().__new__(cls, *args, **kwargs)  # type: ignore reportArgumentType

        return inst
    def __getnewargs__(
        self: "PrincipleMLChecker",
    ) -> tuple:
        """Returns the arguments to be passed to the __new__ method when unpickling."""
        # Passing filepath=None prevents __new__ from unpickling recursively.
        return (None,)

    def _check_rule(
        self: "PrincipleMLChecker",
        rule: idstools.rule.Rule,
    ) -> ISSUES_TYPE:
        issues: ISSUES_TYPE = []

        if len(self._models) == 0:
            return issues

        for code, model in self._models.items():
            if model.predict(self._get_features(rule, True))[0]:
                issues.append(
                    Issue(
                        code=code,
                        message=get_message(code),
                    )
                )

        return issues
    def train(  # noqa: C901
        self: "PrincipleMLChecker",
        df: DataFrame,
        rule_col: str = "rule.rule",
        principle_cols: dict[str, str] = {
            "Q000": "labelled.no_proxy",
            "Q001": "labelled.success",
            "Q002": "labelled.thresholded",
            "Q003": "labelled.exceptions",
            "Q004": "labelled.generalized_match_content",
            "Q005": "labelled.generalized_match_location",
        },
        reuse_models: bool = False,
    ) -> None:
        """Train several models for the checker to detect issues in rules.

        The checker class with trained models is stored in a pickle file (`_PICKLE_PATH`).
        """
        self._dtypes = None
        if not reuse_models:
            self._models = {}

        # Extract features and determine feature dtypes
        X_train = self._get_train_df(df[rule_col])  # noqa: N806

        for col in X_train.columns:
            try:
                X_train[col].var()
                _logger.debug("Detected column: %s", col)
            except:
                _logger.error("Error with column %s", col)
                _logger.error(X_train[col])
                raise

        # Drop zero-variance columns
        X_train = X_train.drop(  # noqa: N806
            X_train.columns[(X_train.fillna(-1337).var(axis=0) <= 0)].to_list(),  # type: ignore reportAttributeAccessIssue
            axis=1,
        )

        # Drop columns with too few occurrences of possible values
        for col in X_train.columns:
            if (
                not col.endswith(".count")
                and not col.endswith(".num")
                and not col.endswith(".len")
            ):
                if X_train[col].value_counts().min() <= 1:
                    X_train = X_train.drop(  # noqa: N806
                        [col],
                        axis=1,
                    )

        for col in X_train.columns:
            try:
                X_train[col].var()
                _logger.info("Using column: %s", col)
            except:
                _logger.error("Error with column %s", col)
                _logger.error(X_train[col])
                raise

        # Store used features and their dtypes
        self._dtypes = X_train.dtypes.to_dict()
        _logger.debug(self._dtypes)

        # Redo feature extraction now that FE parameters are set
        X_train = self._get_train_df(df[rule_col])  # noqa: N806

        _logger.info(
            "Training model with features: [%s]",
            ", ".join([str(x) for x in X_train.columns]),
        )

        _logger.info(X_train)

        for code, col in principle_cols.items():
            y_true = df[col].to_numpy() == 0

            if not reuse_models or code not in self._models:
                # Train a new model, using grid search to find optimal parameters
                gridsearchcv: GridSearchCV = copy.deepcopy(GRIDSEARCHCV)

                gridsearchcv.fit(X_train, y_true)

                _logger.info("Code %s params: %s", code, gridsearchcv.best_params_)
                _logger.info(
                    "Code %s Weighted F1-score: %s", code, gridsearchcv.best_score_
                )

                self._models[code] = gridsearchcv.best_estimator_

            precision = cross_val_score(
                self._models[code],
                X_train,
                y_true,
                scoring=make_scorer(precision_score, zero_division=0.0),
                cv=SPLITTER,
                n_jobs=N_JOBS,
            ).mean()
            recall = cross_val_score(
                self._models[code],
                X_train,
                y_true,
                scoring=make_scorer(recall_score, zero_division=0.0),
                cv=SPLITTER,
                n_jobs=N_JOBS,
            ).mean()
            f1 = cross_val_score(
                self._models[code],
                X_train,
                y_true,
                scoring=make_scorer(f1_score, zero_division=0.0),
                cv=SPLITTER,
                n_jobs=N_JOBS,
            ).mean()
            _logger.info("Code %s Precision score: %s", code, precision)
            _logger.info("Code %s Recall score: %s", code, recall)
            _logger.info("Code %s F1-score: %s", code, f1)

            # Refit the model on the full training data.
            self._models[code].fit(X_train, y_true)

        with open(_PICKLE_PATH, "wb") as f:
            pickle.dump(self, f)

    def _get_train_df(self: "PrincipleMLChecker", rules: Iterable[str]) -> DataFrame:
        feature_vectors = []
        for rule in rules:
            parsed_rule = idstools.rule.parse(rule)
            assert parsed_rule is not None
            feature_vectors.append(self._get_features(parsed_rule, False))

        return DataFrame(feature_vectors)

    def _get_raw_features(  # noqa: C901
        self: "PrincipleMLChecker", rule: idstools.rule.Rule
    ) -> Series:
        d: dict[str, Optional[Union[str, int]]] = {
            "proto": get_rule_option(rule, "proto")
        }

        options = rule["options"]

        for option in options:
            d[option["name"]] = option["value"]

        counter = Counter([option["name"] for option in options])
        for option, count in counter.items():
            d[option + ".count"] = count

        for option in options:
            if option["name"] not in self.splittable_features:
                continue

            suboptions = [
                {"name": k, "value": v}
                for k, v in get_rule_suboptions(rule, option["name"], warn=False)
            ]

            if len(suboptions) == 0:
                continue

            for suboption in suboptions:
                d[option["name"] + "." + suboption["name"]] = suboption["value"]

            counter = Counter([suboption["name"] for suboption in suboptions])
            for suboption, count in counter.items():
                d[option["name"] + "." + suboption + ".count"] = count

        msg = get_rule_option(rule, "msg")
        assert msg is not None
        msg = msg.lower()
        for col, keyword in zip(self.msg_columns, self.msg_keywords):
            d[col] = keyword.lower() in msg

        source_addr = get_rule_option(rule, "source_addr")
        assert source_addr is not None
        source_addr = source_addr.lower()
        for keyword in self.ip_keywords:
            col = "source_addr.contains." + keyword
            d[col] = keyword.lower() in source_addr

        dest_addr = get_rule_option(rule, "dest_addr")
        assert dest_addr is not None
        dest_addr = dest_addr.lower()
        for keyword in self.ip_keywords:
            col = "dest_addr.contains." + keyword
            d[col] = keyword.lower() in dest_addr

        return Series(d)

    def _preprocess_features(self: "PrincipleMLChecker", data: Series) -> Series:
        original_cols: set[str] = set(data.index)

        for col in self.string_columns:
            if col not in data:
                continue
            data[col + ".len"] = len(data[col])
            data = data.drop(col)

        for col in self.dropdown_columns:
            if col not in data:
                continue
            data[col + "." + data[col] + ".bool"] = 1
            data = data.drop(col)

        for col in self.numerical_columns:
            if col not in data:
                continue
            data[col + ".num"] = float(data[col])
            data = data.drop(col)

        remaining_cols = (
            original_cols
            - set(self.count_columns)
            - set(self.string_columns)
            - set(self.dropdown_columns)
            - set(self.numerical_columns)
            - set(self.msg_columns)
            - set(self.ip_columns)
        )

        for col in remaining_cols:
            data = data.drop(col)

        return data

    @overload
    def _get_features(
        self: "PrincipleMLChecker", rule: idstools.rule.Rule, frame: Literal[True]
    ) -> DataFrame:
        pass

    @overload
    def _get_features(
        self: "PrincipleMLChecker", rule: idstools.rule.Rule, frame: Literal[False]
    ) -> Series:
        pass

    def _get_features_frame(self: "PrincipleMLChecker", features: Series) -> DataFrame:
        features_frame = features.to_frame().transpose()

        if self._dtypes is None:
            return features_frame

        for col, dtype in self._dtypes.items():
            if features_frame.dtypes[col] != dtype:
                features_frame[col] = features_frame[col].astype(dtype)

        return features_frame

    def _get_features(
        self: "PrincipleMLChecker", rule: idstools.rule.Rule, frame: bool
    ) -> Union[Series, DataFrame]:
        features: Series = self._get_raw_features(rule)
        features = self._preprocess_features(features)

        features["custom.negated.count"] = rule["raw"].count(':!"')

        if self._dtypes is None:
            return features

        for col, dtype in self._dtypes.items():
            if col not in features:
                if col.endswith(".count"):
                    features[col] = 0
                elif col.endswith(".bool"):
                    features[col] = 0
                elif col.endswith(".num"):
                    features[col] = -1
                else:
                    _logger.error(
                        "Unsure how to handle missing feature %s of type %s",
                        col,
                        dtype,
                    )

        features = features[list(self._dtypes.keys())]  # type: ignore reportAssignmentType

        if not frame:
            return features

        return self._get_features_frame(features)
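

# Example usage (a sketch: it assumes a trained model pickle exists at
# _PICKLE_PATH and that `CheckerInterface` exposes the usual public
# `check_rule` wrapper around `_check_rule`; the rule string is hypothetical):
#
#   import idstools.rule
#   from suricata_check.checkers.principle.ml import PrincipleMLChecker
#
#   checker = PrincipleMLChecker()  # unpickles the trained models
#   rule = idstools.rule.parse(
#       'alert tcp $EXTERNAL_NET any -> $HOME_NET any (msg:"Example"; sid:1; rev:1;)'
#   )
#   if rule is not None:
#       issues = checker.check_rule(rule)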