Source code for suricata_check.checkers.principle.ml

"""`PrincipleMLChecker`."""

import copy
import logging
import os
import pickle
from collections import Counter
from collections.abc import Iterable
from typing import Any, Literal, Optional, Union, overload

import idstools.rule
import xgboost
from pandas import DataFrame, Series
from sklearn.metrics import f1_score, make_scorer, precision_score, recall_score
from sklearn.model_selection import (
    GridSearchCV,
    RepeatedStratifiedKFold,
    cross_val_score,
)
from sklearn.pipeline import Pipeline

from suricata_check._version import SURICATA_CHECK_DIR
from suricata_check.checkers.interface.checker import CheckerInterface
from suricata_check.checkers.principle._utils import get_message
from suricata_check.utils.checker import get_rule_option, get_rule_suboptions
from suricata_check.utils.checker_typing import ISSUES_TYPE, Issue

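# Path to the pickled pre-trained checker bundled with the package, and the
# number of parallel jobs used for grid search and cross-validation below.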
_PICKLE_PATH = os.path.join(SURICATA_CHECK_DIR, "data", "principle_ml_checker.pkl")
N_JOBS = 8


_logger = logging.getLogger(__name__)


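# Feature-column groups used during feature extraction: ".count" columns hold
# keyword occurrence counts, string columns are replaced by their length,
# dropdown columns are one-hot encoded, and numerical columns are cast to
# floats (see `_preprocess_features`).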
COUNT_COLUMNS = (
    "flowbits.isset.count",
    "flowbits.isnotset.count",
    "flowint.isset.count",
    "flowint.isnotset.count",
    "xbits.isset.count",
    "xbits.isnotset.count",
    "http.uri.count",
    "http.method.count",
    "dns.query.count",
    "content.count",
    "pcre.count",
    "startswith.count",
    "bsize.count",
    "depth.count",
    "urilen.count",
    "flow.from_server.count",
    "flow.to_server.count",
    "flow.from_client.count",
    "flow.to_client.count",
)
STRING_COLUMNS = ()
DROPDOWN_COLUMNS = (
    "proto",
    "threshold.type",
)
NUMERICAL_COLUMNS = ("threshold.count",)
SPLITTABLE_FEATURES = (
    "metadata",
    "flow",
    "threshold",
)
MSG_KEYWORDS = ("Suspicious", "CVE", "Vulnerability", "Response")
# Note: this must be a tuple, not a generator expression, since it is iterated
# multiple times (a generator would be exhausted after the first rule).
MSG_COLUMNS = tuple("msg.contains." + keyword for keyword in MSG_KEYWORDS)
IP_KEYWORDS = ("$HOME_NET", "$HTTP_SERVERS", "$EXTERNAL_NET", "any")
IP_COLUMNS = tuple(
    ["source_addr.contains." + keyword for keyword in IP_KEYWORDS]
    + ["dest_addr.contains." + keyword for keyword in IP_KEYWORDS]
)


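# Single-step sklearn Pipeline wrapping an XGBoost classifier; `train`
# grid-searches and fits a fresh copy of it for each principle code.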
PIPELINE = Pipeline(
    [
        (
            "classify",
            xgboost.XGBClassifier(),
        )
    ]
)
# https://shengyg.github.io/repository/machine%20learning/2017/02/25/Complete-Guide-to-Parameter-Tuning-xgboost.html
PARAM_GRID: list[dict] = [
    {
        # Fixed parameters for problem / desired complexity
        "classify__n_estimators": [1000],
        "classify__objective": ["binary:logistic"],
        ###
        # Parameters to optimize
        ## Learning rate
        "classify__eta": [0.01, 0.1, 0.3],
        ## Tree parameters
        "classify__subsample": [1.0],
        "classify__colsample_bytree": [0.25, 0.5, 0.75, 1.0],
        "classify__scale_pos_weight": [0.1, 0.25, 0.5, 1.0, 2.0, 4.0, 10.0],
        "classify__max_depth": [1, 3],
        "classify__min_child_weight": [1],
        "classify__gamma": [0, 0.1],
        ## Regularization
        "classify__lambda": [0, 0.01, 0.1],
        "classify__alpha": [0, 0.01, 0.1],
    },
]

PRECISION_WEIGHT = 10
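# Weighted harmonic mean of precision and recall, weighting precision
# PRECISION_WEIGHT times as heavily as recall (equivalent to an F-beta score
# with beta^2 = 1 / PRECISION_WEIGHT, i.e. beta ~= 0.32); the 1e-10 terms
# guard against division by zero.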
SCORER = make_scorer(
    lambda y, y_pred: (PRECISION_WEIGHT + 1)
    / (
        PRECISION_WEIGHT / (precision_score(y, y_pred, zero_division=1) + 1e-10)  # type: ignore reportArgumentType
        + 1 / (recall_score(y, y_pred, zero_division=0) + 1e-10)  # type: ignore reportArgumentType
    )
)
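# Stratified 2-fold cross-validation repeated 10 times: 20 fits per parameter combination.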
SPLITTER = RepeatedStratifiedKFold(n_splits=2, n_repeats=10)
GRIDSEARCHCV = GridSearchCV(
    PIPELINE, PARAM_GRID, cv=SPLITTER, scoring=SCORER, error_score="raise", n_jobs=N_JOBS, verbose=1  # type: ignore reportArgumentType
)


class PrincipleMLChecker(CheckerInterface):
    """The `PrincipleMLChecker` contains several checks based on the Ruling the Unruly paper, targeting rule specificity and coverage.

    Codes Q000-Q005 report on non-adherence to rule design principles, similar to codes P000-P005.
    In contrast, these reports are the result of machine learning analysis of the rules.
    """

    count_columns = COUNT_COLUMNS
    string_columns = STRING_COLUMNS
    dropdown_columns = DROPDOWN_COLUMNS
    numerical_columns = NUMERICAL_COLUMNS
    splittable_features = SPLITTABLE_FEATURES
    msg_keywords = MSG_KEYWORDS
    msg_columns = MSG_COLUMNS
    ip_keywords = IP_KEYWORDS
    ip_columns = IP_COLUMNS

    codes = {
        "Q000": {"severity": logging.INFO},
        "Q001": {"severity": logging.INFO},
        "Q002": {"severity": logging.INFO},
        "Q003": {"severity": logging.INFO},
        "Q004": {"severity": logging.INFO},
        "Q005": {"severity": logging.INFO},
    }

    enabled_by_default = (
        False  # Since the checker is relatively slow, it is disabled by default
    )

    _dtypes: Optional[dict[str, Any]] = None
    _models: dict[str, Pipeline] = {}

    def __new__(
        cls: type["PrincipleMLChecker"],
        filepath: Optional[str] = _PICKLE_PATH,
        *args: tuple,
        **kwargs: dict,
    ) -> "PrincipleMLChecker":
        """Returns a new or unpickled instance of the class."""
        if filepath:
            if os.path.exists(filepath):
                with open(filepath, "rb") as f:
                    inst = pickle.load(f)

                # BEGIN LEGACY CODE
                # CAN BE REMOVED AFTER TRAINING NEW PKL
                if hasattr(inst, "models"):
                    inst._models = inst.models  # noqa: SLF001
                    inst._dtypes = inst.dtypes  # noqa: SLF001
                # END LEGACY CODE

                if inst.__class__.__name__ != cls.__name__:
                    _logger.error("Unpickled object is not of type %s", cls)
                    inst = super().__new__(cls, *args, **kwargs)  # type: ignore reportArgumentType
                elif not hasattr(inst, "_models") or len(inst._models) == 0:
                    _logger.error("Unpickled object does not have trained models")
                    inst = super().__new__(cls, *args, **kwargs)  # type: ignore reportArgumentType
                else:
                    if "include" in kwargs:
                        inst.include = kwargs["include"]
                    _logger.info("Unpickled object with trained models successfully")
            else:
                _logger.warning("No model found for PrincipleMLChecker at %s", filepath)
                inst = super().__new__(cls, *args, **kwargs)  # type: ignore reportArgumentType
        else:
            inst = super().__new__(cls, *args, **kwargs)  # type: ignore reportArgumentType

        return inst

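    # Returning (None,) below makes unpickling call __new__ with filepath=None,
    # so __new__ does not attempt to load the pickle again recursively.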
    def __getnewargs__(
        self: "PrincipleMLChecker",
    ) -> tuple:
        """Returns the arguments to be passed to the __new__ method when unpickling."""
        return (None,)

    def _check_rule(
        self: "PrincipleMLChecker",
        rule: idstools.rule.Rule,
    ) -> ISSUES_TYPE:
        issues: ISSUES_TYPE = []

        if len(self._models) == 0:
            return issues

        # Run each trained per-code model on the rule's feature vector and
        # report an issue for every code the model flags.
        for code, model in self._models.items():
            if model.predict(self._get_features(rule, True))[0]:
                issues.append(
                    Issue(
                        code=code,
                        message=get_message(code),
                    )
                )

        return issues

    def train(  # noqa: C901
        self: "PrincipleMLChecker",
        df: DataFrame,
        rule_col: str = "rule.rule",
        principle_cols: dict[str, str] = {
            "Q000": "labelled.no_proxy",
            "Q001": "labelled.success",
            "Q002": "labelled.thresholded",
            "Q003": "labelled.exceptions",
            "Q004": "labelled.generalized_match_content",
            "Q005": "labelled.generalized_match_location",
        },
        reuse_models: bool = False,
    ) -> None:
        """Train several models for the checker to detect issues in rules.

        The checker class with trained models is stored in a pickle file (`_PICKLE_PATH`).
        """
        self._dtypes = None
        if not reuse_models:
            self._models = {}

        # Extract features and determine feature dtypes
        X_train = self._get_train_df(df[rule_col])  # noqa: N806

        for col in X_train.columns:
            try:
                X_train[col].var()
                _logger.debug("Detected column: %s", col)
            except Exception:
                _logger.error("Error with column %s", col)
                _logger.error(X_train[col])
                raise

        # Drop zero-variance columns
        X_train = X_train.drop(  # noqa: N806
            X_train.columns[(X_train.fillna(-1337).var(axis=0) <= 0)].to_list(),  # type: ignore reportAttributeAccessIssue
            axis=1,
        )

        # Drop columns with too few occurrences of possible values
        for col in X_train.columns:
            if (
                not col.endswith(".count")
                and not col.endswith(".num")
                and not col.endswith(".len")
            ):
                if X_train[col].value_counts().min() <= 1:
                    X_train = X_train.drop(  # noqa: N806
                        [col],
                        axis=1,
                    )

        for col in X_train.columns:
            try:
                X_train[col].var()
                _logger.info("Using column: %s", col)
            except Exception:
                _logger.error("Error with column %s", col)
                _logger.error(X_train[col])
                raise

        # Store used features and their dtypes
        self._dtypes = X_train.dtypes.to_dict()
        _logger.debug(self._dtypes)

        # Redo feature extraction now that FE parameters are set
        X_train = self._get_train_df(df[rule_col])  # noqa: N806

        _logger.info(
            "Training model with features: [%s]",
            ", ".join([str(x) for x in X_train.columns]),
        )

        _logger.info(X_train)

        for code, col in principle_cols.items():
            # Rules labelled 0 for this principle form the positive class.
            y_true = df[col].to_numpy() == 0

            if not reuse_models or code not in self._models:
                # Train new model with grid search to find optimal parameters
                gridsearchcv: GridSearchCV = copy.deepcopy(GRIDSEARCHCV)

                gridsearchcv.fit(X_train, y_true)

                _logger.info("Code %s params: %s", code, gridsearchcv.best_params_)
                _logger.info(
                    "Code %s Weighted F1-score: %s", code, gridsearchcv.best_score_
                )

                self._models[code] = gridsearchcv.best_estimator_

            precision = cross_val_score(
                self._models[code],
                X_train,
                y_true,
                scoring=make_scorer(precision_score, zero_division=0.0),
                cv=SPLITTER,
                n_jobs=N_JOBS,
            ).mean()
            recall = cross_val_score(
                self._models[code],
                X_train,
                y_true,
                scoring=make_scorer(recall_score, zero_division=0.0),
                cv=SPLITTER,
                n_jobs=N_JOBS,
            ).mean()
            f1 = cross_val_score(
                self._models[code],
                X_train,
                y_true,
                scoring=make_scorer(f1_score, zero_division=0.0),
                cv=SPLITTER,
                n_jobs=N_JOBS,
            ).mean()
            _logger.info("Code %s Precision score: %s", code, precision)
            _logger.info("Code %s Recall score: %s", code, recall)
            _logger.info("Code %s F1-score: %s", code, f1)

            # Refit model with training data.
            self._models[code].fit(X_train, y_true)

        with open(_PICKLE_PATH, "wb") as f:
            pickle.dump(self, f)

    def _get_train_df(self: "PrincipleMLChecker", rules: Iterable[str]) -> DataFrame:
        feature_vectors = []
        for rule in rules:
            parsed_rule = idstools.rule.parse(rule)
            assert parsed_rule is not None
            feature_vectors.append(self._get_features(parsed_rule, False))

        return DataFrame(feature_vectors)

    def _get_raw_features(  # noqa: C901
        self: "PrincipleMLChecker", rule: idstools.rule.Rule
    ) -> Series:
        d: dict[str, Optional[Union[str, int]]] = {
            "proto": get_rule_option(rule, "proto")
        }

        options = rule["options"]

        for option in options:
            d[option["name"]] = option["value"]

        # Count how often each option occurs in the rule.
        counter = Counter([option["name"] for option in options])
        for option, count in counter.items():
            d[option + ".count"] = count

        # Expand splittable options (e.g., flow, threshold) into per-suboption features.
        for option in options:
            if option["name"] not in self.splittable_features:
                continue

            suboptions = [
                {"name": k, "value": v}
                for k, v in get_rule_suboptions(rule, option["name"], warn=False)
            ]

            if len(suboptions) == 0:
                continue

            for suboption in suboptions:
                d[option["name"] + "." + suboption["name"]] = suboption["value"]

            counter = Counter([suboption["name"] for suboption in suboptions])
            for suboption, count in counter.items():
                d[option["name"] + "." + suboption + ".count"] = count

        # Binary features indicating keyword presence in the msg field.
        msg = get_rule_option(rule, "msg")
        assert msg is not None
        msg = msg.lower()
        for col, keyword in zip(self.msg_columns, self.msg_keywords):
            d[col] = keyword.lower() in msg

        source_addr = get_rule_option(rule, "source_addr")
        assert source_addr is not None
        source_addr = source_addr.lower()
        for keyword in self.ip_keywords:
            col = "source_addr.contains." + keyword
            d[col] = keyword.lower() in source_addr

        dest_addr = get_rule_option(rule, "dest_addr")
        assert dest_addr is not None
        dest_addr = dest_addr.lower()
        for keyword in self.ip_keywords:
            col = "dest_addr.contains." + keyword
            d[col] = keyword.lower() in dest_addr

        return Series(d)

    def _preprocess_features(self: "PrincipleMLChecker", data: Series) -> Series:
        original_cols: set[str] = set(data.index)

        # Replace string features with their length.
        for col in self.string_columns:
            if col not in data:
                continue
            data[col + ".len"] = len(data[col])
            data = data.drop(col)

        # One-hot encode dropdown features.
        for col in self.dropdown_columns:
            if col not in data:
                continue
            data[col + "." + data[col] + ".bool"] = 1
            data = data.drop(col)

        # Cast numerical features to floats.
        for col in self.numerical_columns:
            if col not in data:
                continue
            data[col + ".num"] = float(data[col])
            data = data.drop(col)

        # Drop any feature not belonging to a known column group.
        remaining_cols = (
            original_cols
            - set(self.count_columns)
            - set(self.string_columns)
            - set(self.dropdown_columns)
            - set(self.numerical_columns)
            - set(self.msg_columns)
            - set(self.ip_columns)
        )

        for col in remaining_cols:
            data = data.drop(col)

        return data

    @overload
    def _get_features(
        self: "PrincipleMLChecker", rule: idstools.rule.Rule, frame: Literal[True]
    ) -> DataFrame:
        pass

    @overload
    def _get_features(
        self: "PrincipleMLChecker", rule: idstools.rule.Rule, frame: Literal[False]
    ) -> Series:
        pass

    def _get_features_frame(self: "PrincipleMLChecker", features: Series) -> DataFrame:
        features_frame = features.to_frame().transpose()

        if self._dtypes is None:
            return features_frame

        # Align dtypes with those recorded during training.
        for col, dtype in self._dtypes.items():
            if features_frame.dtypes[col] != dtype:
                features_frame[col] = features_frame[col].astype(dtype)

        return features_frame

    def _get_features(
        self: "PrincipleMLChecker", rule: idstools.rule.Rule, frame: bool
    ) -> Union[Series, DataFrame]:
        features: Series = self._get_raw_features(rule)
        features = self._preprocess_features(features)

        # Count negated content matches (':!"') in the raw rule.
        features["custom.negated.count"] = rule["raw"].count(':!"')

        if self._dtypes is None:
            return features

        # Impute features that were seen during training but are absent here.
        for col, dtype in self._dtypes.items():
            if col not in features:
                if col.endswith(".count"):
                    features[col] = 0
                elif col.endswith(".bool"):
                    features[col] = 0
                elif col.endswith(".num"):
                    features[col] = -1
                else:
                    _logger.error(
                        "Unsure how to handle missing feature %s of type %s",
                        col,
                        dtype,
                    )

        # Restrict and order features to match the training dtypes.
        features = features[list(self._dtypes.keys())]  # type: ignore reportAssignmentType

        if not frame:
            return features

        return self._get_features_frame(features)
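
# Example usage (an illustrative sketch, not part of the module): assuming the
# public `check_rule` method inherited from `CheckerInterface` wraps
# `_check_rule`, a pre-trained checker could be applied to a parsed rule like so:
#
#     import idstools.rule
#     from suricata_check.checkers.principle.ml import PrincipleMLChecker
#
#     checker = PrincipleMLChecker()  # unpickles the bundled model from _PICKLE_PATH
#     rule = idstools.rule.parse(
#         'alert tcp any any -> any any (msg:"Test rule"; content:"abc"; sid:1;)'
#     )
#     for issue in checker.check_rule(rule):  # hypothetical public wrapper
#         print(issue.code, issue.message)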