|
@@ -0,0 +1,95 @@
|
|
|
+#!/usr/bin/python
|
|
|
+# -*- coding: utf-8 -*-
|
|
|
+
|
|
|
+import json
|
|
|
+import logging
|
|
|
+import redis
|
|
|
+
|
|
|
+from .basestorage import BaseStorage
|
|
|
+
|
|
|
+
|
|
|
+class RedisStorage(BaseStorage):
|
|
|
+ """Redis-backed storage"""
|
|
|
+
|
|
|
+ FIELDSUFFIX_TYPE = '__TYPE'
|
|
|
+
|
|
|
+ def __init__(self,
|
|
|
+ redis_host='127.0.0.1', redis_port=6379, redis_password=None):
|
|
|
+ self.logger = logging.getLogger('RedisStorage')
|
|
|
+
|
|
|
+ self.logger.debug('Connecting to REDIS database at %s on port %n.',
|
|
|
+ redis_host, redis_port)
|
|
|
+ self.db = redis.StrictRedis(host=redis_host, port=redis_port,
|
|
|
+ password=redis_password)
|
|
|
+ self.logger.info('Connected to REDIS database with %n entries.',
|
|
|
+ self.db.dbsize())
|
|
|
+
|
|
|
+ def save(self):
|
|
|
+ self.db.save()
|
|
|
+
|
|
|
+ def get_all_nodes_raw(self):
|
|
|
+ keys = self.db.keys('node_*')
|
|
|
+ nodes = {}
|
|
|
+ for key in keys:
|
|
|
+ node_id = key[5:]
|
|
|
+ node = self.get_node_data(node_id)
|
|
|
+ nodes[node_id] = node
|
|
|
+ return nodes
|
|
|
+
|
|
|
+ def set_node_data(self, key, data):
|
|
|
+ thedata = {}
|
|
|
+ for item in data:
|
|
|
+ payload = data[item]
|
|
|
+ if isinstance(payload, basestring):
|
|
|
+ thedata[item] = data[item]
|
|
|
+ thedata[item + self.FIELDSUFFIX_TYPE] = 'str'
|
|
|
+ elif isinstance(payload, int):
|
|
|
+ thedata[item] = str(data[item])
|
|
|
+ thedata[item + self.FIELDSUFFIX_TYPE] = 'int'
|
|
|
+ elif isinstance(payload, float):
|
|
|
+ thedata[item] = str(data[item])
|
|
|
+ thedata[item + self.FIELDSUFFIX_TYPE] = 'float'
|
|
|
+ else:
|
|
|
+ thedata[item] = json.dumps(data[item])
|
|
|
+ thedata[item + self.FIELDSUFFIX_TYPE] = 'json'
|
|
|
+ self.db.hmset('node_' + key, thedata)
|
|
|
+
|
|
|
+ def get_node_data(self, key):
|
|
|
+ node = {}
|
|
|
+ thedata = self.db.hgetall('node_' + key)
|
|
|
+ for item in thedata:
|
|
|
+ if item.endswith(self.FIELDSUFFIX_TYPE):
|
|
|
+ continue
|
|
|
+
|
|
|
+ fieldtype = thedata.get(item + self.FIELDSUFFIX_TYPE, 'str')
|
|
|
+ payload = thedata[item]
|
|
|
+
|
|
|
+ if fieldtype == 'json':
|
|
|
+ node[item] = json.loads(payload)
|
|
|
+ elif fieldtype == 'int':
|
|
|
+ node[item] = int()
|
|
|
+ elif fieldtype == 'float':
|
|
|
+ node[item] = float()
|
|
|
+ else:
|
|
|
+ node[item] = payload
|
|
|
+
|
|
|
+ return node
|
|
|
+
|
|
|
+ def get_vpn_keys(self):
|
|
|
+ keys = [key[4:] for key in self.db.keys('vpn_*')]
|
|
|
+ return keys
|
|
|
+
|
|
|
+ def get_vpn_item(self, key, create=False):
|
|
|
+ self.check_vpn_key(key)
|
|
|
+ rawdata = self.db.get('vpn_' + key)
|
|
|
+ if rawdata is None:
|
|
|
+ if not create:
|
|
|
+ return None
|
|
|
+ self.store_vpn_item(key, {'active': {}, 'last': {}})
|
|
|
+ rawdata = self.db.get('vpn_' + key)
|
|
|
+ data = json.loads(rawdata)
|
|
|
+ return data
|
|
|
+
|
|
|
+ def store_vpn_item(self, key, data):
|
|
|
+ self.check_vpn_key(key)
|
|
|
+ self.db.set('vpn_' + key, json.dumps(data))
|