You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

289 lines
10 KiB

"""
(C) 2019 Splunk Inc. All rights reserved.
"""
import sys
import os
import uuid
from abc import ABCMeta, abstractmethod
import cloudgateway.private.util.sdk_mode
sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'lib'))
import base64
import requests
from cloudgateway.private.websocket.cloudgateway_connector import CloudgatewayConnector
from cloudgateway.private.encryption.encryption_handler import decrypt_session_token
from cloudgateway.private.messages.send import build_encrypted_payload
from cloudgateway.private.util.splunk_auth_header import SplunkAuthHeader
from cloudgateway.private.util.logger import DummyLogger
from cloudgateway.private.util import constants
from cloudgateway.private.registration.util import sb_message_endpoint, sb_auth_header, requests_ssl_context
from cloudgateway.private.util.config import SplunkConfig
from cloudgateway.private.websocket.parent_process_monitor import ParentProcessMonitor
from spacebridge_protocol import http_pb2
from enum import Enum
if sys.version_info >= (3,0):
from cloudgateway.private.asyncio.websocket.aio_parent_process_monitor import AioParentProcessMonitor
try:
# conditional import which will only work if using splunk's python
if sys.version_info < (3,0):
from cloudgateway.splunk.twisted.cluster.cluster_monitor import ClusterMonitor
else:
from cloudgateway.splunk.asyncio.cluster.cluster_monitor import ClusterMonitor
except:
pass
class WebsocketMode(Enum):
    """
    Supported ways of running the websocket connection.

    ASYNC: the websocket is driven by a single threaded event loop; callbacks are expected to be non-blocking
        and to return deferred objects.
    THREADED: callbacks are allowed to block; every received message is handed off to a thread taken from a pool.
    """

    # Single threaded event loop, non-blocking callbacks
    ASYNC = 0

    # Thread pool, blocking callbacks permitted
    THREADED = 1
class ServerResponse(object):
    """
    Response message to be sent back to either the client or Cloudgateway: a payload paired with a request id.
    """

    def __init__(self, payload, request_id=None):
        """
        Args:
            payload: content of the response, stored unchanged
            request_id: identifier to attach to the response. Any falsy value (None, empty string, ...) is
                replaced by a freshly generated guid.
        """
        self.payload = payload
        self.request_id = request_id or ServerResponse.create_rid()

    @staticmethod
    def create_rid():
        """
        Generate a unique request id
        :return: string form of a random (version 4) uuid
        """
        return str(uuid.uuid4())

    def __repr__(self):
        # Rendered as the string form of a plain dict holding both fields
        fields = {
            'payload': self.payload,
            'request_id': self.request_id
        }
        return str(fields)
class CloudGatewayWsClient(object):
    """
    Client for talking to Cloudgateway: builds a CloudgatewayConnector for the websocket connection and offers
    an HTTP based send() for pushing an encrypted message to a single device.
    """

    def __init__(self, encryption_context, message_handler, logger=None, mode=WebsocketMode.THREADED,
                 config=None, shard_id=None, websocket_context=None, key_bundle=None):
        """
        Args:
            encryption_context: [EncryptionContext] Can be a regular EncryptionContext or a subclass such
                as SplunkEncryptionContext depending on whether you want to run in standalone mode or not.
                Its `mode` attribute decides whether Splunk specific monitors are created.
            message_handler: [AbstractMessageHandler] interface specifying how to handle messages from Cloudgateway
            logger: Optional logger parameter for logging purposes. A DummyLogger is used when omitted.
            mode: [WebsocketMode] Enum specifying either threaded mode or async mode. Defaults to Threaded mode
            config: Optional [CloudgatewaySdkConfig] configuration class. When omitted (None) a new SplunkConfig
                is created for this client.
            shard_id: Optional shard id forwarded to the connector. In Splunk mode a ClusterMonitor is only
                created when this is None.
            websocket_context: Optional context object forwarded to the connector
                (presumably an AbstractWebsocketContext - confirm against CloudgatewayConnector)
            key_bundle: Optional key bundle forwarded to the connector and used by send() to set up the
                client certificate for the HTTP request
        Raises:
            ValueError: if mode is neither WebsocketMode.THREADED nor WebsocketMode.ASYNC
        """
        self.encryption_context = encryption_context
        self.logger = logger or DummyLogger()
        self.logger.info(encryption_context)
        self.message_handler = message_handler
        self.mode = mode
        # The default config is built here rather than in the signature: a default expression in the signature
        # is evaluated once at definition time, which would run at import and be shared by every client.
        self.config = config if config is not None else SplunkConfig()
        self.shard_id = shard_id
        self.key_bundle = key_bundle

        if self.mode == WebsocketMode.THREADED:
            websocket_mode = constants.THREADED_MODE
        elif self.mode == WebsocketMode.ASYNC:
            websocket_mode = constants.ASYNC_MODE
        else:
            raise ValueError("Unsupported websocket mode")

        if self.encryption_context.mode == cloudgateway.private.util.sdk_mode.SdkMode.SPLUNK:
            session_key = self.encryption_context.session_key

            # The parent process monitor implementation depends on the python major version
            if sys.version_info < (3, 0):
                parent_process_monitor = ParentProcessMonitor()
            else:
                parent_process_monitor = AioParentProcessMonitor()

            # NOTE: ClusterMonitor comes from a conditional import at the top of the file; if that import
            # failed the name is undefined and this line raises NameError.
            cluster_monitor = None
            if self.shard_id is None:
                cluster_monitor = ClusterMonitor(self.logger)
        else:
            # Outside of Splunk mode there is no session key and no monitors are used
            session_key = ""
            parent_process_monitor = None
            cluster_monitor = None

        self.connector = CloudgatewayConnector(self.message_handler,
                                               self.encryption_context,
                                               session_key,
                                               parent_process_monitor,
                                               cluster_monitor,
                                               self.logger,
                                               self.config,
                                               mode=websocket_mode,
                                               shard_id=self.shard_id,
                                               websocket_context=websocket_context,
                                               key_bundle=self.key_bundle
                                               )

    def connect(self, threadpool_size=None):
        """
        Initiate websocket connection to cloudgateway
        Args:
            threadpool_size: [Integer] Size of threadpool to use. Only applies in Threaded Mode.
        Returns: None
        """
        self.connector.connect(threadpool_size)

    def send(self, device_id, payload, request_id=None):
        """Send a message to a particular device using Cloud Gateway
        Args:
            device_id ([binary]): id of the device to send the message to
            payload ([string]): Message to be sent. Can be any format you want (json, serialized proto, etc.)
            request_id (str, optional): Identifier attached to the message. When omitted (or otherwise falsy)
                a new guid is generated via ServerResponse.create_rid().
        Returns:
            [requests.response]: response returned by requests object
        """
        if not request_id:
            request_id = ServerResponse.create_rid()

        send_message_request = http_pb2.SendMessageRequest()
        recipient_info = self.message_handler.fetch_device_info(device_id)

        # The request's signedEnvelope is handed to the builder, which presumably fills it in place
        # (its return value is not used here) - confirm against build_encrypted_payload.
        build_encrypted_payload(recipient_info,
                                self.encryption_context,
                                payload,
                                request_id,
                                self.logger,
                                signed_envelope=send_message_request.signedEnvelope)

        spacebridge_header = {'Authorization': sb_auth_header(self.encryption_context)}

        # The message is posted over HTTP (not over the websocket), using the context's `name` as client cert
        with requests_ssl_context(self.key_bundle) as cert:
            return requests.post(sb_message_endpoint(self.config),
                                 headers=spacebridge_header,
                                 data=send_message_request.SerializeToString(),
                                 cert=cert.name
                                 )
class AbstractWebsocketContext(object):
    """
    Optional context class for callers who want finer grained control over the behaviour of the websocket.
    Subclasses provide the four event callbacks declared below.
    """
    # NOTE: the `__metaclass__` attribute is only honoured by Python 2. Under Python 3 it is an ordinary class
    # attribute, so the @abstractmethod markers below are not enforced at instantiation time.
    __metaclass__ = ABCMeta

    # Presumably the delay, in seconds, between connection retries - the consumer is not visible here, confirm
    RETRY_INTERVAL_SECONDS = 2

    @abstractmethod
    def on_open(self, protocol):
        """Callback for the websocket open event (named after it; caller not visible here). Does nothing."""

    @abstractmethod
    def on_ping(self, payload, protocol):
        """Callback for a websocket ping carrying `payload`. Does nothing."""

    @abstractmethod
    def on_pong(self, payload, protocol):
        """Callback for a websocket pong carrying `payload`. Does nothing."""

    @abstractmethod
    def on_close(self, wasClean, code, reason, protocol):
        """Callback for the websocket close event. Does nothing."""
class AbstractMessageHandler(object):
    """
    Interface whose methods are delegated to when a message is received from Cloudgateway.
    """
    # NOTE: `__metaclass__` is only honoured by Python 2; under Python 3 the abstract methods are not enforced
    # at instantiation time and simply raise NotImplementedError when called.
    __metaclass__ = ABCMeta

    def __init__(self, encryption_context):
        """
        If you override the constructor, make sure to call this constructor: it stores the encryption context
        needed for encrypting and decrypting messages and derives the system auth header from its session key.
        Args:
            encryption_context: [EncryptionContext] object which is necessary for message handler to decrypt
                incoming messages
        """
        self.encryption_context = encryption_context
        self.system_auth_header = ""

        # Contexts without a session key keep the empty auth header
        if not hasattr(encryption_context, 'session_key'):
            return

        self.system_auth_header = SplunkAuthHeader(encryption_context.session_key)

    @abstractmethod
    def handle_application_message(self, msg, sender, request_id):
        """
        Args:
            msg: message payload sent by a client device
            sender: byte array of device sending message
            request_id: string representing unique identifier of message sent by the sender
        Returns: ServerResponse or List[ServerResponse] which represent payloads to be sent back to the
            client device.
        """
        raise NotImplementedError

    @abstractmethod
    def handle_cloudgateway_message(self, msg):
        """
        Args:
            msg: message payload string sent by cloud gateway
        Returns: ServerResponse or List[ServerResponse]
        """
        raise NotImplementedError

    @abstractmethod
    def fetch_device_info(self, device_id):
        """
        Given a device id, fetch the DeviceInfo object associated to that device
        Args:
            device_id: byte array representing device id
        Returns: DeviceInfo object
        """
        raise NotImplementedError

    def decrypt_session_token(self, encrypted_session_token):
        """
        Args:
            encrypted_session_token: An encrypted session token string, base64 encoded
        Returns: Decrypted session token string
        """
        context = self.encryption_context
        public_key = context.encrypt_public_key()
        private_key = context.encrypt_private_key()

        # The token arrives base64 encoded; the decryption routine works on the raw bytes
        raw_token = base64.b64decode(encrypted_session_token)

        # Inside the method body this name resolves to the module level decrypt_session_token import,
        # not to this method
        return decrypt_session_token(context.sodium_client, raw_token, public_key, private_key)