You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

86 lines
2.5 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# Needed to use PSC build
# NOTE(review): presumably this switches execution to the PSC-provided Python
# interpreter (or aborts), which is why it runs before every other import --
# confirm against exec_anaconda.
import exec_anaconda
exec_anaconda.exec_anaconda_or_die()
# Needed to import libraries in /lib folder
# The "../lib" directory (relative to this file) is placed at the front of the
# module search path so that it takes precedence over other locations.
import sys, os
import time
import uuid
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "lib"))
# Keep these imports below the sys.path.insert call above; presumably they are
# resolved from the lib folder that was just added -- TODO confirm.
from splunklib.searchcommands import dispatch, GeneratingCommand, Configuration, Option, Boolean
import setup_logging
# Module-level logger used by the search command defined in this file.
logger = setup_logging.get_logger()
def is_meme(spl):
    """Report whether a string is dominated by one repeated token.

    The text is split on whitespace. It is flagged when any token longer than
    one character occurs more than ``len(tokens) // 4`` times. Texts with
    fewer than 10 tokens are never flagged.

    Args:
        spl: The string to inspect.

    Returns:
        True if some multi-character token exceeds the repetition threshold,
        otherwise False.
    """
    # Imported locally, as in the rest of this file's lazy-import style.
    from collections import Counter

    tokens = spl.split()
    total = len(tokens)
    # Short texts are always accepted.
    if total < 10:
        return False

    # Integer quarter of the token count; a token must exceed it to count.
    threshold = total // 4
    return any(
        len(word) > 1 and occurrences > threshold
        for word, occurrences in Counter(tokens).items()
    )
def postprocess(choices):
    """Select the prediction to return from a list of candidates.

    The first candidate that ``is_meme`` does not flag is chosen. If every
    candidate is flagged, the first candidate is returned anyway.

    Args:
        choices: Indexable sequence of candidate strings (presumably ordered
            best-first by the predictor -- verify against the caller).

    Returns:
        The selected candidate string.

    Raises:
        IndexError: If ``choices`` is empty.
    """
    # Every candidate is screened, not only those up to the first survivor.
    survivors = list(filter(lambda candidate: not is_meme(candidate), choices))
    if survivors:
        return survivors[0]
    # Nothing passed the filter: fall back to the first candidate.
    return choices[0]
@Configuration()
class SPLGenCommand(GeneratingCommand):
    """
    Generating command that yields a predicted SPL query for a plain English
    description of a query.

    Options:
        prompt: Required plain English description of the query.
        explain: Optional boolean (default False). When true, the first raw
            prediction is emitted in an ``EXPLAIN`` field instead of a
            post-processed query in an ``SPL`` field.
    """
    # Required plain English description of a query
    prompt = Option(require=True)
    # Optional boolean switch selecting the EXPLAIN output instead of SPL
    explain = Option(default=False, require=False, validate=Boolean())

    def generate(self):
        """
        Yields a single record holding the prediction for ``self.prompt``.

        In explain mode the record carries the first prediction as ``EXPLAIN``.
        Otherwise the predictions go through ``postprocess`` and the result is
        emitted as ``SPL``, after warnings are written for wildcard
        index/sourcetype searches and for joins.
        """
        started_at = time.time()
        logger.info("Starting generate.")
        run_id = str(uuid.uuid4())
        logger.info(f"UUID={run_id}")

        # Needed to reduce overhead in "Parsing search"
        from spl_gen.t5onnx.predict import Predictor

        model = Predictor()

        # Make predictions; the prompt is lower-cased before it is passed on.
        query_text = self.prompt.lower()
        explain_mode = self.explain
        # Both output modes use the same prediction call.
        candidates = model.predict(query_text)

        if explain_mode:
            yield self.gen_record(EXPLAIN=candidates[0])
        else:
            spl = postprocess(candidates)
            if any(wildcard in spl for wildcard in ("index=*", "sourcetype=*")):
                self.write_warning("Using 'index=*' or 'sourcetype=*' can potentially be very expensive for your stack.")
            if "| join" in spl:
                self.write_warning("The join operation is usually very expensive.")
            yield self.gen_record(SPL=spl)

        # Runs only once the consumer resumes the generator after the record,
        # so the measured time includes the consumer's handling of it.
        elapsed = round(time.time() - started_at, 5)
        logger.info(f"command=splgen, apply_time={elapsed}")
        logger.info("Exiting generate.")
# Script entry point: hand the command class to splunklib's dispatcher.
# NOTE(review): no module_name argument is passed, so this presumably runs the
# command whenever the module is loaded, not only when it is executed as a
# script -- confirm against splunklib's dispatch() before importing this file
# from other code.
dispatch(command_class=SPLGenCommand)