Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions hashtopolis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

# models
from .hashtopolis import (
ApiToken,
AccessGroup,
Agent,
AgentStat,
Expand Down
101 changes: 66 additions & 35 deletions hashtopolis/hashtopolis.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,16 @@ def __init__(self):
self.username = self._cfg['username']
self.password = self._cfg['password']

@classmethod
def with_credentials(cls, uri, username, password):
"""Create a config with explicit credentials instead of reading from a config file."""
config = cls.__new__(cls)
config._hashtopolis_uri = uri
config._api_endpoint = uri + '/api/v2'
config.username = username
config.password = password
return config


class HashtopolisResponseError(HashtopolisError):
pass
Expand Down Expand Up @@ -106,22 +116,34 @@ def __init__(self, model_uri, config):
self._hashtopolis_uri = config._hashtopolis_uri
self.config = config

def authenticate(self):
if self._api_endpoint not in HashtopolisConnector.token:
# Request access TOKEN, used throughout the test
def authenticate(self, auth=None):
"""
Authenticate with the API and store the token for future requests.

logger.info("Start authentication")
Args:
auth: Authentication object understood by requests, typically a
``(username, password)`` tuple. Is only used for one off authentication
that differ from the config. This authentication is not cached.
"""
if auth is not None:
logger.info("Start authentication with provided credentials")
auth_uri = self._api_endpoint + '/auth/token'
auth = (self.config.username, self.config.password)
r = requests.post(auth_uri, auth=auth)
self.validate_status_code(r, [201], "Authentication failed")

r_json = self.resp_to_json(r)
HashtopolisConnector.token[self._api_endpoint] = r_json['token']
HashtopolisConnector.token_expires[self._api_endpoint] = r_json['token']

self._token = HashtopolisConnector.token[self._api_endpoint]
self._token_expires = HashtopolisConnector.token_expires[self._api_endpoint]
self._token = r_json['token']
self._token_expires = r_json['token']
Comment thread
jessevz marked this conversation as resolved.
else:
if self._api_endpoint not in HashtopolisConnector.token:
logger.info("Start authentication")
auth_uri = self._api_endpoint + '/auth/token'
r = requests.post(auth_uri, auth=(self.config.username, self.config.password))
self.validate_status_code(r, [201], "Authentication failed")
r_json = self.resp_to_json(r)
HashtopolisConnector.token[self._api_endpoint] = r_json['token']
HashtopolisConnector.token_expires[self._api_endpoint] = r_json['token']
self._token = HashtopolisConnector.token[self._api_endpoint]
self._token_expires = HashtopolisConnector.token_expires[self._api_endpoint]

self._headers = {
'Authorization': 'Bearer ' + self._token
Expand Down Expand Up @@ -190,9 +212,9 @@ def validate_status_code(self, r, expected_status_code, error_msg):
# query_params = urllib.parse.parse_qs(urllib.parse.urlparse(links["last"]).query)
# TODO not really a straightforward way to validate the last link

def get_single_page(self, page, filter):
def get_single_page(self, page, filter, auth=None):
"""Gets a single page by using the page parameters"""
self.authenticate()
self.authenticate(auth=auth)
headers = self._headers
request_uri = self._api_endpoint + self._model_uri
payload = {}
Expand All @@ -215,8 +237,8 @@ def get_single_page(self, page, filter):
return response["data"]

# todo refactor start_offset into page variable
def filter(self, include, ordering, filter, start_offset):
self.authenticate()
def filter(self, include, ordering, filter, start_offset, auth=None):
self.authenticate(auth=auth)
headers = self._headers
Comment thread
jessevz marked this conversation as resolved.

after_dict = {"primary": {"id": start_offset}}
Expand Down Expand Up @@ -253,8 +275,8 @@ def filter(self, include, ordering, filter, start_offset):
break
request_uri = response['links']['next']

def get_one(self, pk, include):
self.authenticate()
def get_one(self, pk, include, auth=None):
self.authenticate(auth=auth)
uri = self._api_endpoint + self._model_uri + f'/{pk}'
headers = self._headers

Expand All @@ -266,8 +288,8 @@ def get_one(self, pk, include):
self.validate_status_code(r, [200], "Get single object failed")
return self.resp_to_json(r)

def delete_many(self, objects):
self.authenticate()
def delete_many(self, objects, auth=None):
self.authenticate(auth=auth)
uri = self._api_endpoint + self._model_uri
headers = self._headers
headers['Content-Type'] = 'application/json'
Expand All @@ -282,7 +304,7 @@ def delete_many(self, objects):
r = requests.delete(uri, headers=headers, data=json.dumps(payload))
self.validate_status_code(r, [204], "deleting failed")

def patch_many(self, objects, attributes, field):
def patch_many(self, objects, attributes, field, auth=None):
"""
Used to test PATCH many endpoint.

Expand All @@ -293,7 +315,7 @@ def patch_many(self, objects, attributes, field):
patched with attributes[0] on the set field
"""
assert len(objects) == len(attributes)
self.authenticate()
self.authenticate(auth=auth)
uri = self._api_endpoint + self._model_uri
headers = self._headers
headers['Content-Type'] = 'application/json'
Expand All @@ -302,12 +324,12 @@ def patch_many(self, objects, attributes, field):
r = requests.patch(uri, headers=headers, data=json.dumps(payload))
self.validate_status_code(r, [200], "Patching failed")

def patch_one(self, obj):
def patch_one(self, obj, auth=None):
if not obj.has_changed():
logger.debug("Object '%s' has not changed, no PATCH required", obj)
return

self.authenticate()
self.authenticate(auth=auth)
uri = self._hashtopolis_uri + obj.uri
headers = self._headers
headers['Content-Type'] = 'application/json'
Expand All @@ -325,29 +347,29 @@ def patch_one(self, obj):
# TODO: Validate if return objects matches digital twin
obj.set_initial(self.resp_to_json(r)['data'].copy())

def send_patch(self, uri, data):
self.authenticate()
def send_patch(self, uri, data, auth=None):
self.authenticate(auth=auth)
headers = self._headers
headers['Content-Type'] = 'application/json'
logger.debug("Sending PATCH payload: %s to %s", json.dumps(data), uri)
r = requests.patch(uri, headers=headers, data=json.dumps(data))
self.validate_status_code(r, [204], "Patching failed")

def patch_to_many_relationships(self, obj):
def patch_to_many_relationships(self, obj, auth=None):
for k, v in obj.diff_includes().items():
attributes = []
logger.debug("Going to patch object '%s' property '%s' from '%s' to '%s'", obj, k, v[0], v[1])
for include_id in v[1]:
attributes.append({"type": k, "id": include_id})
data = {"data": attributes}
uri = self._hashtopolis_uri + obj.uri + "/relationships/" + k
self.send_patch(uri, data)
self.send_patch(uri, data, auth=auth)

def create(self, obj):
def create(self, obj, auth=None):
# Check if object to be created is new
assert obj._new_model is True

self.authenticate()
self.authenticate(auth=auth)
uri = self._api_endpoint + self._model_uri
headers = self._headers
headers['Content-Type'] = 'application/json'
Expand All @@ -362,12 +384,12 @@ def create(self, obj):
# TODO: Validate if return objects matches digital twin
obj.set_initial(self.resp_to_json(r)['data'].copy())

def delete(self, obj):
def delete(self, obj, auth=None):
""" Delete object from database """
# TODO: Check if object to be deleted actually exists
assert obj._new_model is False

self.authenticate()
self.authenticate(auth=auth)
uri = self._hashtopolis_uri + obj.uri
headers = self._headers
payload = {}
Expand All @@ -377,8 +399,8 @@ def delete(self, obj):

# TODO: Cleanup object to allow re-creation

def count(self, filter):
self.authenticate()
def count(self, filter, auth=None):
self.authenticate(auth=auth)
uri = self._api_endpoint + self._model_uri + "/count"
headers = self._headers
payload = {}
Expand All @@ -394,12 +416,13 @@ def count(self, filter):

# Build Django ORM style django.query interface
class QuerySet():
def __init__(self, cls, include=None, ordering=None, filters=None, pages=None):
def __init__(self, cls, include=None, ordering=None, filters=None, pages=None, auth=None):
self.cls = cls
self.include = include
self.ordering = ordering
self.filters = filters
self.pages = pages
self.auth = auth

def __iter__(self):
yield from self.__getitem__(slice(None, None, 1))
Expand Down Expand Up @@ -431,7 +454,7 @@ def filter_(self, start, stop, step):
filters['id'] = filters['pk']
del filters['pk']

filter_generator = self.cls.get_conn().filter(self.include, self.ordering, filters, start_offset=cursor)
filter_generator = self.cls.get_conn().filter(self.include, self.ordering, filters, start_offset=cursor, auth=self.auth)

while index < stop:
# Fetch new entries in chunks default to server
Expand Down Expand Up @@ -469,6 +492,10 @@ def page(self, **pages):
def all(self):
# yield from self
return self

def authenticate(self, auth):
self.auth = auth
return self

Comment thread
jessevz marked this conversation as resolved.
def get(self, **filters):
if filters:
Expand Down Expand Up @@ -760,6 +787,10 @@ def uri(self):
##
# Begin of API objects
#
class ApiToken(Model, uri="/ui/apiTokens"):
pass


class AccessGroup(Model, uri="/ui/accessgroups"):
pass

Expand Down