You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
289 lines
10 KiB
289 lines
10 KiB
"""
|
|
(C) 2019 Splunk Inc. All rights reserved.
|
|
"""
|
|
import sys
|
|
import os
|
|
import uuid
|
|
from abc import ABCMeta, abstractmethod
|
|
|
|
import cloudgateway.private.util.sdk_mode
|
|
|
|
sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'lib'))
|
|
|
|
import base64
|
|
import requests
|
|
from cloudgateway.private.websocket.cloudgateway_connector import CloudgatewayConnector
|
|
from cloudgateway.private.encryption.encryption_handler import decrypt_session_token
|
|
from cloudgateway.private.messages.send import build_encrypted_payload
|
|
from cloudgateway.private.util.splunk_auth_header import SplunkAuthHeader
|
|
from cloudgateway.private.util.logger import DummyLogger
|
|
from cloudgateway.private.util import constants
|
|
from cloudgateway.private.registration.util import sb_message_endpoint, sb_auth_header, requests_ssl_context
|
|
from cloudgateway.private.util.config import SplunkConfig
|
|
from cloudgateway.private.websocket.parent_process_monitor import ParentProcessMonitor
|
|
from spacebridge_protocol import http_pb2
|
|
from enum import Enum
|
|
|
|
if sys.version_info >= (3,0):
|
|
from cloudgateway.private.asyncio.websocket.aio_parent_process_monitor import AioParentProcessMonitor
|
|
|
|
|
|
try:
    # Conditional import which will only work if using Splunk's python. When it fails, ClusterMonitor is simply
    # left undefined in this module.
    if sys.version_info < (3, 0):
        from cloudgateway.splunk.twisted.cluster.cluster_monitor import ClusterMonitor
    else:
        from cloudgateway.splunk.asyncio.cluster.cluster_monitor import ClusterMonitor
except Exception:
    # Still deliberately best-effort, but unlike a bare `except:` this no longer swallows
    # KeyboardInterrupt / SystemExit raised while the import is running.
    pass
|
|
|
|
|
|
class WebsocketMode(Enum):
    """
    Supported ways of initiating the websocket connection.

    ASYNC: the websocket runs on a single threaded event loop; callbacks are expected to be non-blocking and to
        return deferred objects.
    THREADED: callbacks are allowed to block; every received message is delegated to a thread from a pool.
    """
    ASYNC = 0
    THREADED = 1
|
|
|
|
|
|
class ServerResponse(object):
    """
    Response message to be sent back to either the client or Cloudgateway: a payload together with the id of the
    request it belongs to.
    """

    def __init__(self, payload, request_id=None):
        """
        Args:
            payload: content of the response
            request_id: id to attach to the response; when omitted (or otherwise falsy) a fresh guid is generated
        """
        self.payload = payload
        self.request_id = request_id or ServerResponse.create_rid()

    @staticmethod
    def create_rid():
        """
        Generate a new unique id.

        :return: string form of a random (version 4) UUID
        """
        return str(uuid.uuid4())

    def __repr__(self):
        # Rendered as the string form of a plain dict holding both fields
        fields = {'payload': self.payload, 'request_id': self.request_id}
        return str(fields)
|
|
|
|
|
|
class CloudGatewayWsClient(object):
    """
    Client which wires a message handler and an encryption context into a CloudgatewayConnector, and which can
    send encrypted messages to devices over HTTP.
    """

    def __init__(self, encryption_context, message_handler, logger=None, mode=WebsocketMode.THREADED,
                 config=None, shard_id=None, websocket_context=None, key_bundle=None):
        """
        Args:
            encryption_context: [EncryptionContext] Can be a regular EncryptionContext or a subclass such
                as SplunkEncryptionContext depending on whether you want to run in standalone mode or not.
            message_handler: [AbstractMessageHandler] interface specifying how to handle messages from Cloudgateway
            logger: Optional logger parameter for logging purposes. Defaults to a DummyLogger
            mode: [WebsocketMode] Enum specifying either threaded mode or async mode. Defaults to Threaded mode
            config: Optional [CloudgatewaySdkConfig] configuration class. When omitted, a new SplunkConfig is
                created for this client
            shard_id: Optional shard id forwarded to the connector. In Splunk mode a ClusterMonitor is only
                created when this is None
            websocket_context: Optional context object forwarded to the connector
                (presumably an AbstractWebsocketContext - confirm against CloudgatewayConnector)
            key_bundle: Optional key bundle forwarded to the connector and used to build the client certificate
                when sending messages

        Raises:
            ValueError: if mode is not a supported WebsocketMode
        """
        self.encryption_context = encryption_context
        self.logger = logger or DummyLogger()
        self.logger.info(encryption_context)
        self.message_handler = message_handler
        self.mode = mode
        # The default config is built per instance. A `config=SplunkConfig()` default in the signature would be
        # evaluated once at definition time and shared by every client.
        self.config = config if config is not None else SplunkConfig()
        self.shard_id = shard_id
        self.key_bundle = key_bundle

        if self.mode == WebsocketMode.THREADED:
            websocket_mode = constants.THREADED_MODE
        elif self.mode == WebsocketMode.ASYNC:
            websocket_mode = constants.ASYNC_MODE
        else:
            raise ValueError("Unsupported websocket mode")

        if self.encryption_context.mode == cloudgateway.private.util.sdk_mode.SdkMode.SPLUNK:
            session_key = self.encryption_context.session_key
            # The parent process monitor implementation depends on the running python version
            if sys.version_info < (3, 0):
                parent_process_monitor = ParentProcessMonitor()
            else:
                parent_process_monitor = AioParentProcessMonitor()

            cluster_monitor = None
            if self.shard_id is None:
                cluster_monitor = ClusterMonitor(self.logger)
        else:
            # Outside of Splunk mode there is no session key and no monitors are attached
            session_key = ""
            parent_process_monitor = None
            cluster_monitor = None

        self.connector = CloudgatewayConnector(self.message_handler,
                                               self.encryption_context,
                                               session_key,
                                               parent_process_monitor,
                                               cluster_monitor,
                                               self.logger,
                                               self.config,
                                               mode=websocket_mode,
                                               shard_id=self.shard_id,
                                               websocket_context=websocket_context,
                                               key_bundle=self.key_bundle
                                               )

    def connect(self, threadpool_size=None):
        """
        Initiate websocket connection to cloudgateway

        Args:
            threadpool_size: [Integer] Size of threadpool to use. Only applies in Threaded Mode.

        Returns: None
        """
        self.connector.connect(threadpool_size)

    def send(self, device_id, payload, request_id=None):
        """Send a message to a particular device using Cloud Gateway

        Args:
            device_id ([binary]): id of the device to send the message to
            payload ([string]): Message to be sent. Can be any format you want (json, serialized proto, etc.)
            request_id (str, optional): id attached to the message. Defaults to None, in which case a new
                unique id is generated.

        Returns:
            [requests.response]: response returned by requests object
        """
        if not request_id:
            request_id = ServerResponse.create_rid()

        send_message_request = http_pb2.SendMessageRequest()
        recipient_info = self.message_handler.fetch_device_info(device_id)

        # Fills in send_message_request.signedEnvelope for the recipient
        build_encrypted_payload(recipient_info,
                                self.encryption_context,
                                payload,
                                request_id,
                                self.logger,
                                signed_envelope=send_message_request.signedEnvelope)

        spacebridge_header = {'Authorization': sb_auth_header(self.encryption_context)}

        # NOTE(review): no timeout is passed to requests.post, so this call can block indefinitely
        with requests_ssl_context(self.key_bundle) as cert:
            return requests.post(sb_message_endpoint(self.config),
                                 headers=spacebridge_header,
                                 data=send_message_request.SerializeToString(),
                                 cert=cert.name
                                 )
|
|
|
|
|
|
|
|
class AbstractWebsocketContext(object):
    """
    Optional context class if you want finer grain control over behaviour of websocket.

    Subclasses provide the on_open / on_ping / on_pong / on_close callbacks declared below.
    """

    # NOTE(review): the `__metaclass__` attribute is only honoured by Python 2. On Python 3 it is ignored, so the
    # @abstractmethod markers below are not enforced there.
    __metaclass__ = ABCMeta

    RETRY_INTERVAL_SECONDS = 2

    @abstractmethod
    def on_open(self, protocol):
        """Callback for the websocket open event (presumably receives the websocket protocol object - confirm)."""

    @abstractmethod
    def on_ping(self, payload, protocol):
        """Callback for a websocket ping carrying `payload`."""

    @abstractmethod
    def on_pong(self, payload, protocol):
        """Callback for a websocket pong carrying `payload`."""

    @abstractmethod
    def on_close(self, wasClean, code, reason, protocol):
        """Callback for the websocket close event, given the close details `wasClean`, `code` and `reason`."""
|
|
|
|
|
|
class AbstractMessageHandler(object):
    """
    Interface to which handling of messages received from Cloudgateway is delegated.
    """
    __metaclass__ = ABCMeta

    def __init__(self, encryption_context):
        """
        If you override the constructor, make sure to call this constructor: it stores the encryption context
        which is necessary for encrypting and decrypting messages.

        Args:
            encryption_context: [EncryptionContext] object which is necessary for message handler to decrypt
                incoming messages
        """
        self.encryption_context = encryption_context

        # A Splunk auth header is only built when the context carries a session key; otherwise it stays empty
        if hasattr(encryption_context, 'session_key'):
            self.system_auth_header = SplunkAuthHeader(encryption_context.session_key)
        else:
            self.system_auth_header = ""

    @abstractmethod
    def handle_application_message(self, msg, sender, request_id):
        """
        Handle a message sent by a client device.

        Args:
            msg: message payload sent by a client device
            sender: byte array of device sending message
            request_id: string representing unique identifier of message sent by the sender

        Returns: ServerResponse or List[ServerResponse] which represent payloads to be sent back to the client
            device.
        """
        raise NotImplementedError

    @abstractmethod
    def handle_cloudgateway_message(self, msg):
        """
        Handle a message sent by cloud gateway itself.

        Args:
            msg: message payload string sent by cloud gateway

        Returns: ServerResponse or List[ServerResponse]
        """
        raise NotImplementedError

    @abstractmethod
    def fetch_device_info(self, device_id):
        """
        Given device id, fetch DeviceInfo object associated to that device

        Args:
            device_id: byte array representing device id

        Returns: DeviceInfo object
        """
        raise NotImplementedError

    def decrypt_session_token(self, encrypted_session_token):
        """
        Decrypt a base64 encoded session token using the key pair of the encryption context.

        Args:
            encrypted_session_token: An encrypted session token string

        Returns: Decrypted session token string
        """
        context = self.encryption_context
        encrypt_public = context.encrypt_public_key()
        encrypt_private = context.encrypt_private_key()
        token_bytes = base64.b64decode(encrypted_session_token)
        # Resolves to the module level decrypt_session_token helper, not to this method
        return decrypt_session_token(context.sodium_client, token_bytes, encrypt_public, encrypt_private)
|
|
|
|
|