diff --git a/nilmdb/server/bulkdata.py b/nilmdb/server/bulkdata.py index 01e7adf..983ce31 100644 --- a/nilmdb/server/bulkdata.py +++ b/nilmdb/server/bulkdata.py @@ -25,9 +25,9 @@ fd_cache_size = 8 @nilmdb.utils.must_close(wrap_verify = False) class BulkData(object): def __init__(self, basepath, **kwargs): - self.basepath = basepath - self.root = os.path.join(self.basepath, "data") - self.lock = self.root + ".lock" + self.basepath = self._encode_filename(basepath) + self.root = os.path.join(self.basepath, b"data") + self.lock = self.root + b".lock" self.lockfile = None # Tuneables @@ -56,7 +56,8 @@ class BulkData(object): # Create the lock self.lockfile = open(self.lock, "w") if not nilmdb.utils.lock.exclusive_lock(self.lockfile): - raise IOError('database at "' + self.basepath + + raise IOError('database at "' + + self._decode_filename(self.basepath) + '" is already locked by another process') def close(self): @@ -71,16 +72,20 @@ class BulkData(object): self.lockfile = None def _encode_filename(self, path): - # Encode all paths to UTF-8, regardless of sys.getfilesystemencoding(), - # because we want to be able to represent all code points and the user - # will never be directly exposed to filenames. We can then do path - # manipulations on the UTF-8 directly. + # Translate unicode strings to raw bytes, if needed. We + # always manipulate paths internally as bytes. if isinstance(path, str): return path.encode('utf-8') return path + def _decode_filename(self, path): + # Translate raw bytes to unicode strings, escaping if needed + if isinstance(path, bytes): + return path.decode('utf-8', errors='backslashreplace') + return path + def _create_check_ospath(self, ospath): - if ospath[-1] == '/': + if ospath[-1:] == b'/': raise ValueError("invalid path; should not end with a /") if Table.exists(ospath): raise ValueError("stream already exists at this path") @@ -97,13 +102,13 @@ class BulkData(object): don't exist. 
Returns a list of elements that got created.""" path = self._encode_filename(unicodepath) - if path[0] != '/': - raise ValueError("paths must start with /") - [ group, node ] = path.rsplit("/", 1) - if group == '': + if path[0:1] != b'/': + raise ValueError("paths must start with /") + [ group, node ] = path.rsplit(b"/", 1) + if group == b'': raise ValueError("invalid path; path must contain at least one " "folder") - if node == '': + if node == b'': raise ValueError("invalid path; should not end with a /") if not Table.valid_path(path): raise ValueError("path name is invalid or contains reserved words") @@ -114,7 +119,7 @@ class BulkData(object): # os.path.join) # Make directories leading up to this one - elements = path.lstrip('/').split('/') + elements = path.lstrip(b'/').split(b'/') made_dirs = [] try: # Make parent elements @@ -176,7 +181,7 @@ class BulkData(object): def _remove_leaves(self, unicodepath): """Remove empty directories starting at the leaves of unicodepath""" path = self._encode_filename(unicodepath) - elements = path.lstrip('/').split('/') + elements = path.lstrip(b'/').split(b'/') for i in reversed(list(range(len(elements)))): ospath = os.path.join(self.root, *elements[0:i+1]) try: @@ -191,9 +196,9 @@ class BulkData(object): newpath = self._encode_filename(newunicodepath) # Get OS paths - oldelements = oldpath.lstrip('/').split('/') + oldelements = oldpath.lstrip(b'/').split(b'/') oldospath = os.path.join(self.root, *oldelements) - newelements = newpath.lstrip('/').split('/') + newelements = newpath.lstrip(b'/').split(b'/') newospath = os.path.join(self.root, *newelements) # Basic checks @@ -204,8 +209,8 @@ class BulkData(object): self.getnode.cache_remove(self, oldunicodepath) # Move the table to a temporary location - tmpdir = tempfile.mkdtemp(prefix = "rename-", dir = self.root) - tmppath = os.path.join(tmpdir, "table") + tmpdir = tempfile.mkdtemp(prefix = b"rename-", dir = self.root) + tmppath = os.path.join(tmpdir, b"table") 
os.rename(oldospath, tmppath) try: @@ -233,7 +238,7 @@ class BulkData(object): path = self._encode_filename(unicodepath) # Get OS path - elements = path.lstrip('/').split('/') + elements = path.lstrip(b'/').split(b'/') ospath = os.path.join(self.root, *elements) # Remove Table object from cache @@ -258,7 +263,7 @@ class BulkData(object): """Return a Table object corresponding to the given database path, which must exist.""" path = self._encode_filename(unicodepath) - elements = path.lstrip('/').split('/') + elements = path.lstrip(b'/').split(b'/') ospath = os.path.join(self.root, *elements) return Table(ospath, self.initial_nrows) @@ -271,12 +276,12 @@ class Table(object): @classmethod def valid_path(cls, root): """Return True if a root path is a valid name""" - return "_format" not in root.split("/") + return b"_format" not in root.split(b"/") @classmethod def exists(cls, root): """Return True if a table appears to exist at this OS path""" - return os.path.isfile(os.path.join(root, "_format")) + return os.path.isfile(os.path.join(root, b"_format")) @classmethod def create(cls, root, layout, file_size, files_per_dir): @@ -293,7 +298,7 @@ class Table(object): "files_per_dir": files_per_dir, "layout": layout, "version": 3 } - with open(os.path.join(root, "_format"), "wb") as f: + with open(os.path.join(root, b"_format"), "wb") as f: pickle.dump(fmt, f, 2) # Normal methods @@ -303,7 +308,7 @@ class Table(object): self.initial_nrows = initial_nrows # Load the format - with open(os.path.join(self.root, "_format"), "rb") as f: + with open(os.path.join(self.root, b"_format"), "rb") as f: fmt = pickle.load(f) if fmt["version"] != 3: # pragma: no cover @@ -336,7 +341,7 @@ class Table(object): # greater than the row number of any piece of data that # currently exists, not necessarily all data that _ever_ # existed. - regex = re.compile("^[0-9a-f]{4,}$") + regex = re.compile(b"^[0-9a-f]{4,}$") # Find the last directory. 
We sort and loop through all of them, # starting with the numerically greatest, because the dirs could be @@ -380,8 +385,8 @@ class Table(object): filenum = row // self.rows_per_file # It's OK if these format specifiers are too short; the filenames # will just get longer but will still sort correctly. - dirname = sprintf("%04x", filenum // self.files_per_dir) - filename = sprintf("%04x", filenum % self.files_per_dir) + dirname = sprintf(b"%04x", filenum // self.files_per_dir) + filename = sprintf(b"%04x", filenum % self.files_per_dir) offset = (row % self.rows_per_file) * self.row_size count = self.rows_per_file - (row % self.rows_per_file) return (dirname, filename, offset, count) @@ -533,7 +538,7 @@ class Table(object): ret.append(f.extract_string(offset, count)) remaining -= count row += count - return b"".join(ret) + return "".join(ret) def __getitem__(self, row): """Extract timestamps from a row, with table[n] notation.""" @@ -556,7 +561,7 @@ class Table(object): # file. Only when the list covers the entire extent of the # file will that file be removed. 
datafile = os.path.join(self.root, subdir, filename) - cachefile = datafile + ".removed" + cachefile = datafile + b".removed" try: with open(cachefile, "rb") as f: ranges = pickle.load(f) diff --git a/nilmdb/server/rocket.c b/nilmdb/server/rocket.c index 3cf3b95..4d80afa 100644 --- a/nilmdb/server/rocket.c +++ b/nilmdb/server/rocket.c @@ -160,13 +160,19 @@ static PyObject *Rocket_new(PyTypeObject *type, PyObject *args, PyObject *kwds) static int Rocket_init(Rocket *self, PyObject *args, PyObject *kwds) { const char *layout, *path; + int pathlen; static char *kwlist[] = { "layout", "file", NULL }; - if (!PyArg_ParseTupleAndKeywords(args, kwds, "sz", kwlist, - &layout, &path)) + if (!PyArg_ParseTupleAndKeywords(args, kwds, "sz#", kwlist, + &layout, &path, &pathlen)) return -1; if (!layout) return -1; if (path) { + if (strlen(path) != pathlen) { + PyErr_SetString(PyExc_ValueError, "path must not " + "contain NUL characters"); + return -1; + } if ((self->file = fopen(path, "a+b")) == NULL) { PyErr_SetFromErrno(PyExc_OSError); return -1; @@ -585,7 +591,7 @@ static PyObject *Rocket_extract_string(Rocket *self, PyObject *args) str[len++] = '\n'; } - PyObject *pystr = PyBytes_FromStringAndSize(str, len); + PyObject *pystr = PyUnicode_FromStringAndSize(str, len); free(str); return pystr; err: