Browse Source

move all get/store logic out of BaseStorage

Thus, BaseStorage.data got removed as iterating manually over the
available data is not possible any more.
Helge Jung 9 years ago
parent
commit
b3c76d4d0e
3 changed files with 100 additions and 79 deletions
  1. 1 1
      batcave.py
  2. 50 76
      ffstatus/basestorage.py
  3. 49 2
      ffstatus/filestorage.py

+ 1 - 1
batcave.py

@@ -166,7 +166,7 @@ def main():
             storage.merge_new_data(newdata)
             storage.save()
 
-            logger.debug('I have data for %d nodes.', len(storage.data))
+            logger.debug('I have data for %d nodes.', storage.status['nodes'])
 
         except Exception as err:
             logger.error(str(err))

+ 50 - 76
ffstatus/basestorage.py

@@ -30,27 +30,6 @@ class BaseStorage(object):
     to a file, database, whatever.
     """
 
-    DATAKEY_VPN = '__VPN__'
-    FIELDKEY_UPDATED = '__UPDATED__'
-
-    __data = None
-
-    @property
-    def data(self):
-        """Contains the data handled by this storage."""
-        return self.__data
-
-    def init_data(self, data):
-        """setter for data property"""
-        if self.__data is not None:
-            msg = 'Tried to initialize storage data a second time.'
-            logging.error(msg)
-            raise Exception(msg)
-
-        logging.debug('Setting initial storage data (%d items).',
-                      len(data) if data is not None else 0)
-        self.__data = data
-
     def open(self):
         """
         When overridden in a subclass,
@@ -72,14 +51,6 @@ class BaseStorage(object):
         """
         pass
 
-    def set_node_data(self, node_id, data):
-        """
-        Sets the node's data.
-        This method should be overriden in a subclass,
-        but still call the parent one.
-        """
-        self.__data[node_id] = data
-
     @property
     def status(self):
         """Gets status information on the storage."""
@@ -175,19 +146,20 @@ class BaseStorage(object):
     def get_nodes(self, sortby=None, include_raw_data=False):
         """Gets a list of all known nodes."""
 
-        sorted_ids = self.data.keys()
+        nodes = self.get_all_nodes_raw()
+        sorted_ids = [x for x in nodes]
         if sortby is not None:
             if sortby == 'name':
-                sortkey = lambda x: self.data[x]['hostname'].lower()
-                sorted_ids = sorted(self.data, key=sortkey)
+                sortkey = lambda x: nodes[x]['hostname'].lower()
+                sorted_ids = sorted(sorted_ids, key=sortkey)
             elif sortby == 'id':
-                sorted_ids = sorted(self.data)
+                sorted_ids = sorted(sorted_ids)
 
         result = []
         for nodeid in sorted_ids:
             if nodeid.startswith('__'):
                 continue
-            node = sanitize_node(self.data[nodeid], include_raw_data)
+            node = sanitize_node(nodes[nodeid], include_raw_data)
             result.append(node)
 
         return result
@@ -260,25 +232,43 @@ class BaseStorage(object):
         else:
             return 'offline'
 
+    def set_node_data(self, key, data):
+        """Overwrite data for the node with the given key."""
+        raise NotImplementedError("set_node_data was not overridden")
+
+    def check_vpn_key(self, key):
+        if key is None or re.match(r'^[a-fA-F0-9]+$', key) is None:
+            raise VpnKeyFormatError(key)
+
+    def get_vpn_keys(self):
+        """Gets a list of VPN keys."""
+        raise NotImplementedError("get_vpn_keys was not overriden")
+
+    def get_vpn_item(self, key, create=False):
+        self.check_vpn_key(key)
+        raise NotImplementedError("store_vpn_item was not overriden")
+
+    def store_vpn_item(self, key, data):
+        raise NotImplementedError("store_vpn_item was not overriden")
+
     def resolve_vpn_remotes(self):
-        if not self.DATAKEY_VPN in self.data:
-            return
+        """Iterates all remotes and resolves IP blocks not yet resolved."""
+        vpn = self.get_vpn_keys()
 
-        vpn = self.data[self.DATAKEY_VPN]
         init_vpn_cache = {}
         for key in vpn:
-            if not isinstance(vpn[key], dict):
-                continue
+            entry = self.get_vpn_item(key)
+            entry_modified = False
 
-            for mode in vpn[key]:
-                if not isinstance(vpn[key][mode], dict):
+            for mode in entry:
+                if not isinstance(entry[mode], dict):
                     continue
 
-                for gateway in vpn[key][mode]:
-                    if not isinstance(vpn[key][mode][gateway], dict):
+                for gateway in entry[mode]:
+                    if not isinstance(entry[mode][gateway], dict):
                         continue
 
-                    item = vpn[key][mode][gateway]
+                    item = entry[mode][gateway]
                     if 'remote' in item and not 'remote_raw' in item:
                         item['remote_raw'] = item['remote']
                         resolved = None
@@ -296,48 +286,28 @@ class BaseStorage(object):
 
                         if resolved is not None:
                             item['remote'] = resolved
+                            entry_modified = True
 
-        self.save()
-
-    def __get_vpn_item(self, key, create=False):
-        if key is None or re.match(r'^[a-fA-F0-9]+$', key) is None:
-            raise VpnKeyFormatError(key)
-            return
-
-        if not self.DATAKEY_VPN in self.data:
-            if not create:
-                return None
-            self.data[self.DATAKEY_VPN] = {}
-
-        if not key in self.data[self.DATAKEY_VPN]:
-            if not create:
-                return None
-            self.data[self.DATAKEY_VPN][key] = {'active': {}, 'last': {}}
-
-        return self.data[self.DATAKEY_VPN][key]
+            if entry_modified:
+                self.store_vpn_item(key, entry)
 
     def get_vpn_gateways(self):
-        if not self.DATAKEY_VPN in self.data:
-            return []
-
         gateways = set()
-        vpn = self.data[self.DATAKEY_VPN]
+        vpn = self.get_vpn_keys()
         for key in vpn:
-            for conntype in vpn[key]:
-                for gateway in vpn[key][conntype]:
+            entry = self.get_vpn_item(key)
+            for conntype in entry:
+                for gateway in entry[conntype]:
                     gateways.add(gateway)
 
         return sorted(gateways)
 
     def get_vpn_connections(self):
-        if not self.DATAKEY_VPN in self.data:
-            return []
-
         conntypes = ['active', 'last']
         result = []
-        vpn = self.data[self.DATAKEY_VPN]
-        for key in vpn:
-            vpn_entry = vpn[key]
+        vpnkeys = self.get_vpn_keys()
+        for key in vpnkeys:
+            vpn_entry = self.get_vpn_item(key)
             if not isinstance(vpn_entry, dict):
                 continue
 
@@ -369,7 +339,7 @@ class BaseStorage(object):
         return result
 
     def log_vpn_connect(self, key, peername, remote, gateway, timestamp):
-        item = self.__get_vpn_item(key, create=True)
+        item = self.get_vpn_item(key, create=True)
 
         # resolve remote addr to its netblock
         remote_raw = remote
@@ -390,8 +360,10 @@ class BaseStorage(object):
             'remote_raw': remote_raw,
         }
 
+        self.store_vpn_item(key, item)
+
     def log_vpn_disconnect(self, key, gateway, timestamp):
-        item = self.__get_vpn_item(key, create=True)
+        item = self.get_vpn_item(key, create=True)
 
         active = {}
         if gateway in item['active']:
@@ -400,3 +372,5 @@ class BaseStorage(object):
 
         active['disestablish'] = timestamp
         item['last'][gateway] = active
+
+        self.store_vpn_item(key, item)

+ 49 - 2
ffstatus/filestorage.py

@@ -12,6 +12,11 @@ from .basestorage import BaseStorage
 class FileStorage(BaseStorage):
     """Provides file-based persistency for BaseStorage"""
 
+    DATAKEY_VPN = '__VPN__'
+    FIELDKEY_UPDATED = '__UPDATED__'
+
+    __data = None
+
     def __init__(self, storage_dir):
         self.logger = logging.getLogger('Storage')
         self.storage_dir = storage_dir
@@ -43,7 +48,7 @@ class FileStorage(BaseStorage):
             raise err
 
         self.logger.info('Opened storage with %d entries.', len(loaded_data))
-        self.init_data(loaded_data)
+        self.__data = loaded_data
 
     def save(self):
         if self.storage_file is None:
@@ -51,7 +56,7 @@ class FileStorage(BaseStorage):
 
         self.storage_file.seek(0, os.SEEK_SET)
         self.storage_file.truncate()
-        pickle.dump(self.data, self.storage_file, protocol=2)
+        pickle.dump(self.__data, self.storage_file, protocol=2)
         self.storage_file.flush()
 
         # make an auto-backup of the just-written file
@@ -64,3 +69,45 @@ class FileStorage(BaseStorage):
         self.storage_file.close()
         self.storage_file = None
         BaseStorage.close(self)
+
+    def get_all_nodes_raw(self):
+        """Gets all nodes as dict."""
+
+        result = {}
+        for nodeid in self.__data:
+            if nodeid.startswith('__'):
+                continue
+            node = self.__data[nodeid]
+            result[nodeid] = node
+
+        return result
+
+    def set_node_data(self, key, data):
+        self.__data[key] = data
+
+    def get_vpn_keys(self):
+        """Gets a list of VPN keys."""
+        if not self.DATAKEY_VPN in self.__data:
+            return {}
+
+        vpn = [key for key in self.__data[self.DATAKEY_VPN]]
+        return vpn
+
+    def get_vpn_item(self, key, create=False):
+        """Get the VPN entry with the specified key."""
+        if not self.DATAKEY_VPN in self.__data:
+            if not create:
+                return None
+            self.__data[self.DATAKEY_VPN] = {}
+
+        if not key in self.__data[self.DATAKEY_VPN]:
+            if not create:
+                return None
+            self.__data[self.DATAKEY_VPN][key] = {'active': {}, 'last': {}}
+
+        return self.__data[self.DATAKEY_VPN][key]
+
+    def store_vpn_item(self, key, data):
+        if not self.DATAKEY_VPN in self.__data:
+            self.__data[self.DATAKEY_VPN] = {}
+        self.__data[self.DATAKEY_VPN][key] = data