Source code for suricata_check.checkers.principle.ml

  1"""`PrincipleMLChecker`."""
  2
  3import copy
  4import logging
  5import os
  6import pickle
  7from collections import Counter
  8from collections.abc import Iterable
  9from typing import Any, Literal, Optional, Union, overload
 10
 11import idstools.rule
 12import xgboost
 13from pandas import DataFrame, Series
 14from sklearn.metrics import f1_score, make_scorer, precision_score, recall_score
 15from sklearn.model_selection import (
 16    GridSearchCV,
 17    RepeatedStratifiedKFold,
 18    cross_val_score,
 19)
 20from sklearn.pipeline import Pipeline
 21
 22from suricata_check._version import SURICATA_CHECK_DIR
 23from suricata_check.checkers.interface.checker import CheckerInterface
 24from suricata_check.checkers.principle._utils import get_message
 25from suricata_check.utils.checker import get_rule_option, get_rule_suboptions
 26from suricata_check.utils.checker_typing import ISSUES_TYPE, Issue
 27
 28_PICKLE_PATH = os.path.join(SURICATA_CHECK_DIR, "data", "principle_ml_checker.pkl")
 29N_JOBS = 8
 30
 31
 32_logger = logging.getLogger(__name__)
 33
 34
COUNT_COLUMNS = (
    "flowbits.isset.count",
    "flowbits.isnotset.count",
    "flowint.isset.count",
    "flowint.isnotset.count",
    "xbits.isset.count",
    "xbits.isnotset.count",
    "http.uri.count",
    "http.method.count",
    "dns.query.count",
    "content.count",
    "pcre.count",
    "startswith.count",
    "bsize.count",
    "depth.count",
    "urilen.count",
    "flow.from_server.count",
    "flow.to_server.count",
    "flow.from_client.count",
    "flow.to_client.count",
)
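# Note: each count feature above tallies how often the corresponding rule
# (sub)option occurs in a rule; the counts are produced by `_get_raw_features`
# below.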
STRING_COLUMNS = ()
DROPDOWN_COLUMNS = (
    "proto",
    "threshold.type",
)
NUMERICAL_COLUMNS = ("threshold.count",)
SPLITTABLE_FEATURES = (
    "metadata",
    "flow",
    "threshold",
)
MSG_KEYWORDS = ("Suspicious", "CVE", "Vulnerability", "Response")
MSG_COLUMNS = tuple("msg.contains." + keyword for keyword in MSG_KEYWORDS)
IP_KEYWORDS = ("$HOME_NET", "$HTTP_SERVERS", "$EXTERNAL_NET", "any")
IP_COLUMNS = tuple(
    ["source_addr.contains." + keyword for keyword in IP_KEYWORDS]
    + ["dest_addr.contains." + keyword for keyword in IP_KEYWORDS]
)


PIPELINE = Pipeline(
    [
        (
            "classify",
            xgboost.XGBClassifier(),
        )
    ]
)
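# The single-step Pipeline exists so that hyperparameters can be addressed via
# sklearn's `<step>__<parameter>` convention (here `classify__...`) in the
# parameter grid below.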
# https://shengyg.github.io/repository/machine%20learning/2017/02/25/Complete-Guide-to-Parameter-Tuning-xgboost.html
PARAM_GRID: list[dict] = [
    {
        # Fixed parameters for problem / desired complexity
        "classify__n_estimators": [1000],
        "classify__objective": ["binary:logistic"],
        ###
        # Parameters to optimize
        ## Learning rate
        "classify__eta": [0.01, 0.1, 0.3],
        ## Tree parameters
        "classify__subsample": [1.0],
        "classify__colsample_bytree": [0.25, 0.5, 0.75, 1.0],
        "classify__scale_pos_weight": [0.1, 0.25, 0.5, 1.0, 2.0, 4.0, 10.0],
        "classify__max_depth": [1, 3],
        "classify__min_child_weight": [1],
        "classify__gamma": [0, 0.1],
        ## Regularization
        "classify__lambda": [0, 0.01, 0.1],
        "classify__alpha": [0, 0.01, 0.1],
    },
]
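# The grid spans 3 * 4 * 7 * 2 * 2 * 3 * 3 = 3,024 parameter combinations,
# so a full search is costly; see the cross-validation setup below for the
# resulting number of fits.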

PRECISION_WEIGHT = 10
SCORER = make_scorer(
    lambda y, y_pred: (PRECISION_WEIGHT + 1)
    / (
        PRECISION_WEIGHT / (precision_score(y, y_pred, zero_division=1) + 1e-10)  # type: ignore reportArgumentType
        + 1 / (recall_score(y, y_pred, zero_division=0) + 1e-10)  # type: ignore reportArgumentType
    )
)
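# The scorer is a weighted harmonic mean of precision P and recall R:
# (w + 1) / (w / P + 1 / R) with w = PRECISION_WEIGHT. This is equivalent to
# an F-beta score with beta^2 = 1 / w (beta ~= 0.32), so precision is weighted
# roughly ten times as heavily as recall; the 1e-10 terms guard against
# division by zero.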
SPLITTER = RepeatedStratifiedKFold(n_splits=2, n_repeats=10)
GRIDSEARCHCV = GridSearchCV(
    PIPELINE, PARAM_GRID, cv=SPLITTER, scoring=SCORER, error_score="raise", n_jobs=N_JOBS, verbose=1  # type: ignore reportArgumentType
)
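# With 2 folds repeated 10 times, each of the 3,024 parameter combinations is
# fit 20 times, i.e. over 60,000 fits per principle code; `N_JOBS` controls
# how many of these run in parallel.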


class PrincipleMLChecker(CheckerInterface):
    """The `PrincipleMLChecker` contains several checks based on the Ruling the Unruly paper, targeting rule specificity and coverage.

    Codes Q000-Q005 report on non-adherence to the same rule design principles as the `PrincipleChecker`.
    Differently, they are the result of machine learning analysis of the rules.
    """

    count_columns = COUNT_COLUMNS
    string_columns = STRING_COLUMNS
    dropdown_columns = DROPDOWN_COLUMNS
    numerical_columns = NUMERICAL_COLUMNS
    splittable_features = SPLITTABLE_FEATURES
    msg_keywords = MSG_KEYWORDS
    msg_columns = MSG_COLUMNS
    ip_keywords = IP_KEYWORDS
    ip_columns = IP_COLUMNS

    codes = {
        "Q000": {"severity": logging.INFO},
        "Q001": {"severity": logging.INFO},
        "Q002": {"severity": logging.INFO},
        "Q003": {"severity": logging.INFO},
        "Q004": {"severity": logging.INFO},
        "Q005": {"severity": logging.INFO},
    }

    enabled_by_default = (
        False  # Since the checker is relatively slow, it is disabled by default
    )

    _dtypes: Optional[dict[str, Any]] = None
    _models: dict[str, Pipeline] = {}

    def __new__(
        cls: type["PrincipleMLChecker"],
        filepath: Optional[str] = _PICKLE_PATH,
        *args: tuple,
        **kwargs: dict,
    ) -> "PrincipleMLChecker":
        """Returns a new or unpickled instance of the class."""
        if filepath:
            if os.path.exists(filepath):
                with open(filepath, "rb") as f:
                    inst = pickle.load(f)

                # BEGIN LEGACY CODE
                # CAN BE REMOVED AFTER TRAINING NEW PKL
                if hasattr(inst, "models"):
                    inst._models = inst.models  # noqa: SLF001
                    inst._dtypes = inst.dtypes  # noqa: SLF001
                # END LEGACY CODE

                if inst.__class__.__name__ != cls.__name__:
                    _logger.error("Unpickled object is not of type %s", cls)
                    inst = super().__new__(cls, *args, **kwargs)  # type: ignore reportArgumentType
                elif (
                    not hasattr(inst, "_models")
                    or len(inst._models) == 0  # noqa: SLF001
                ):
                    _logger.error("Unpickled object does not have trained models")
                    inst = super().__new__(cls, *args, **kwargs)  # type: ignore reportArgumentType
                else:
                    if "include" in kwargs:
                        inst.include = kwargs["include"]
                    _logger.info("Successfully unpickled object with trained models")
            else:
                _logger.warning("No model found for PrincipleMLChecker at %s", filepath)
                inst = super().__new__(cls, *args, **kwargs)  # type: ignore reportArgumentType
        else:
            inst = super().__new__(cls, *args, **kwargs)  # type: ignore reportArgumentType

        return inst
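    # Usage sketch: `PrincipleMLChecker()` loads the bundled trained checker
    # from `_PICKLE_PATH` (if present), whereas `PrincipleMLChecker(None)`
    # always yields a fresh, untrained instance.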
    def __getnewargs__(
        self: "PrincipleMLChecker",
    ) -> tuple:
        """Returns the arguments to be passed to the __new__ method when unpickling."""
        return (None,)
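    # Returning `(None,)` makes unpickling call `__new__(cls, None)`, so that
    # restoring an instance from a pickle does not recursively load the pickle
    # file from disk again.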

    def _check_rule(
        self: "PrincipleMLChecker",
        rule: idstools.rule.Rule,
    ) -> ISSUES_TYPE:
        issues: ISSUES_TYPE = []

        if len(self._models) == 0:
            return issues

        for code, model in self._models.items():
            if model.predict(self._get_features(rule, True))[0]:
                issues.append(
                    Issue(
                        code=code,
                        message=get_message(code),
                    )
                )

        return issues
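    # A minimal sketch of direct use (normally the checker is invoked through
    # the suricata-check runner; the rule below is hypothetical):
    #
    #   rule = idstools.rule.parse(
    #       'alert tcp any any -> any any (msg:"Test"; sid:1; rev:1;)'
    #   )
    #   issues = PrincipleMLChecker()._check_rule(rule)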
    def train(  # noqa: C901
        self: "PrincipleMLChecker",
        df: DataFrame,
        rule_col: str = "rule.rule",
        principle_cols: dict[str, str] = {
            "Q000": "labelled.no_proxy",
            "Q001": "labelled.success",
            "Q002": "labelled.thresholded",
            "Q003": "labelled.exceptions",
            "Q004": "labelled.generalized_match_content",
            "Q005": "labelled.generalized_match_location",
        },
        reuse_models: bool = False,
    ) -> None:
        """Train several models for the checker to detect issues in rules.

        The checker class with trained models is stored in a pickle file (`_PICKLE_PATH`).
        """
        self._dtypes = None
        if not reuse_models:
            self._models = {}

        # Extract features and determine feature dtypes
        X_train = self._get_train_df(df[rule_col])  # noqa: N806

        for col in X_train.columns:
            try:
                X_train[col].var()
                _logger.debug("Detected column: %s", col)
            except:
                _logger.error("Error with column %s", col)
                _logger.error(X_train[col])
                raise

        # Drop zero-variance columns
        X_train = X_train.drop(  # noqa: N806
            X_train.columns[(X_train.fillna(-1337).var(axis=0) <= 0)].to_list(),  # type: ignore reportAttributeAccessIssue
            axis=1,
        )

        # Drop columns with too few occurrences of possible values
        for col in X_train.columns:
            if (
                not col.endswith(".count")
                and not col.endswith(".num")
                and not col.endswith(".len")
            ):
                if X_train[col].value_counts().min() <= 1:
                    X_train = X_train.drop(  # noqa: N806
                        [col],
                        axis=1,
                    )

        for col in X_train.columns:
            try:
                X_train[col].var()
                _logger.info("Using column: %s", col)
            except:
                _logger.error("Error with column %s", col)
                _logger.error(X_train[col])
                raise

        # Store used features and their dtypes
        self._dtypes = X_train.dtypes.to_dict()
        _logger.debug(self._dtypes)

        # Redo feature extraction now that FE parameters are set
        X_train = self._get_train_df(df[rule_col])  # noqa: N806

        _logger.info(
            "Training model with features: [%s]",
            ", ".join([str(x) for x in X_train.columns]),
        )

        _logger.info(X_train)

        for code, col in principle_cols.items():
            y_true = df[col].to_numpy() == 0

            if not reuse_models or code not in self._models:
                # Train new model with grid search to find optimal parameters
                gridsearchcv: GridSearchCV = copy.deepcopy(GRIDSEARCHCV)

                gridsearchcv.fit(X_train, y_true)

                _logger.info("Code %s params: %s", code, gridsearchcv.best_params_)
                _logger.info(
                    "Code %s precision-weighted score: %s", code, gridsearchcv.best_score_
                )

                self._models[code] = gridsearchcv.best_estimator_

            precision = cross_val_score(
                self._models[code],
                X_train,
                y_true,
                scoring=make_scorer(precision_score, zero_division=0.0),
                cv=SPLITTER,
                n_jobs=N_JOBS,
            ).mean()
            recall = cross_val_score(
                self._models[code],
                X_train,
                y_true,
                scoring=make_scorer(recall_score, zero_division=0.0),
                cv=SPLITTER,
                n_jobs=N_JOBS,
            ).mean()
            f1 = cross_val_score(
                self._models[code],
                X_train,
                y_true,
                scoring=make_scorer(f1_score, zero_division=0.0),
                cv=SPLITTER,
                n_jobs=N_JOBS,
            ).mean()
            _logger.info("Code %s Precision score: %s", code, precision)
            _logger.info("Code %s Recall score: %s", code, recall)
            _logger.info("Code %s F1-score: %s", code, f1)

            # Refit model with training data.
            self._models[code].fit(X_train, y_true)

        with open(_PICKLE_PATH, "wb") as f:
            pickle.dump(self, f)
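    # A training sketch (assumes a labelled DataFrame with a `rule.rule` column
    # and the `labelled.*` columns listed above; the CSV path is hypothetical):
    #
    #   import pandas as pd
    #   df = pd.read_csv("labelled_rules.csv")
    #   checker = PrincipleMLChecker(None)  # start from an untrained instance
    #   checker.train(df)                   # writes `_PICKLE_PATH` when done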

    def _get_train_df(self: "PrincipleMLChecker", rules: Iterable[str]) -> DataFrame:
        feature_vectors = []
        for rule in rules:
            parsed_rule = idstools.rule.parse(rule)
            feature_vectors.append(self._get_features(parsed_rule, False))

        return DataFrame(feature_vectors)

    def _get_raw_features(  # noqa: C901
        self: "PrincipleMLChecker", rule: idstools.rule.Rule
    ) -> Series:
        d: dict[str, Optional[Union[str, int]]] = {
            "proto": get_rule_option(rule, "proto")
        }

        options = rule["options"]

        for option in options:
            d[option["name"]] = option["value"]

        counter = Counter([option["name"] for option in options])
        for option, count in counter.items():
            d[option + ".count"] = count

        for option in options:
            if option["name"] not in self.splittable_features:
                continue

            suboptions = [
                {"name": k, "value": v}
                for k, v in get_rule_suboptions(rule, option["name"], warn=False)
            ]

            if len(suboptions) == 0:
                continue

            for suboption in suboptions:
                d[option["name"] + "." + suboption["name"]] = suboption["value"]

            counter = Counter([suboption["name"] for suboption in suboptions])
            for suboption, count in counter.items():
                d[option["name"] + "." + suboption + ".count"] = count

        msg = get_rule_option(rule, "msg")
        assert msg is not None
        msg = msg.lower()
        for col, keyword in zip(self.msg_columns, self.msg_keywords):
            d[col] = keyword.lower() in msg

        source_addr = get_rule_option(rule, "source_addr")
        assert source_addr is not None
        source_addr = source_addr.lower()
        for keyword in self.ip_keywords:
            col = "source_addr.contains." + keyword
            d[col] = keyword.lower() in source_addr

        dest_addr = get_rule_option(rule, "dest_addr")
        assert dest_addr is not None
        dest_addr = dest_addr.lower()
        for keyword in self.ip_keywords:
            col = "dest_addr.contains." + keyword
            d[col] = keyword.lower() in dest_addr

        return Series(d)

    def _preprocess_features(self: "PrincipleMLChecker", data: Series) -> Series:
        original_cols: set[str] = set(data.index)

        for col in self.string_columns:
            if col not in data:
                continue
            data[col + ".len"] = len(data[col])
            data = data.drop(col)

        for col in self.dropdown_columns:
            if col not in data:
                continue
            data[col + "." + data[col] + ".bool"] = 1
            data = data.drop(col)

        for col in self.numerical_columns:
            if col not in data:
                continue
            data[col + ".num"] = float(data[col])  # type: ignore reportArgumentType
            data = data.drop(col)

        remaining_cols = (
            original_cols
            - set(self.count_columns)
            - set(self.string_columns)
            - set(self.dropdown_columns)
            - set(self.numerical_columns)
            - set(self.msg_columns)
            - set(self.ip_columns)
        )

        for col in remaining_cols:
            data = data.drop(col)

        return data

    @overload
    def _get_features(
        self: "PrincipleMLChecker", rule: idstools.rule.Rule, frame: Literal[True]
    ) -> DataFrame:
        pass

    @overload
    def _get_features(
        self: "PrincipleMLChecker", rule: idstools.rule.Rule, frame: Literal[False]
    ) -> Series:
        pass

    def _get_features_frame(self: "PrincipleMLChecker", features: Series) -> DataFrame:
        features_frame = features.to_frame().transpose()

        if self._dtypes is None:
            return features_frame

        for col, dtype in self._dtypes.items():
            if features_frame.dtypes[col] != dtype:
                features_frame[col] = features_frame[col].astype(dtype)

        return features_frame

    def _get_features(
        self: "PrincipleMLChecker", rule: idstools.rule.Rule, frame: bool
    ) -> Union[Series, DataFrame]:
        features: Series = self._get_raw_features(rule)
        features = self._preprocess_features(features)

        features["custom.negated.count"] = rule["raw"].count(':!"')

        if self._dtypes is None:
            return features

        for col, dtype in self._dtypes.items():
            if col not in features:
                if col.endswith(".count"):
                    features[col] = 0
                elif col.endswith(".bool"):
                    features[col] = 0
                elif col.endswith(".num"):
                    features[col] = -1
                else:
                    _logger.error(
                        "Unsure how to handle missing feature %s of type %s",
                        col,
                        dtype,
                    )

        features = features[list(self._dtypes.keys())]  # type: ignore reportAssignmentType

        if not frame:
            return features

        return self._get_features_frame(features)
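# Illustrative feature extraction (hand-derived, not captured program output):
# a `tcp` rule with `flow:established,to_server;`, two `content` options, and
# source address `$EXTERNAL_NET` yields features such as `content.count == 2`,
# `flow.to_server.count == 1`, `proto.tcp.bool == 1`, and
# `source_addr.contains.$EXTERNAL_NET == True` after preprocessing.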