#!/usr/bin/python

# Copyright (c) 2010 Peter Palfrader
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

# an ssh command wrapper,
#
# stores a file supplied by the calling host. We use this for postgres
# backups, storing both base backups and WAL files.
#

import sys
import os
import optparse
import re
import subprocess
import syslog
import tempfile
import stat
import hashlib

# Root directory below which every file class has its own subdirectory.
basedir = '/srv/backups'
# File classes a remote host may name; each must exist as a directory
# directly below basedir.
accepted_fileclasses = ['pg']
# Chunk size (bytes) used for all streaming reads.
block_size = 4096

# Usage: debbackup-ssh-wrap [<options>] <host>
#  via ssh orig command: <host> store-file <file-class> <file-name> <size> <sha512>
#                        <host> retrieve-file <file-class> <from-host> <file-name>
#
# NOTE: the code below is written to run on both Python 2.6+ and Python 3.


def info(m):
    """Log message m to syslog at INFO priority."""
    syslog.syslog(syslog.LOG_INFO, m)


def croak(m):
    """Log m to syslog as a warning, echo it on stderr, and exit with status 1.

    This function never returns; it raises SystemExit.
    """
    syslog.syslog(syslog.LOG_WARNING, m)
    sys.stderr.write("%s\n" % (m,))
    sys.exit(1)


def _binary_stream(stream):
    """Return the byte-oriented side of a standard stream.

    On Python 3 the standard streams are text wrappers that expose the
    underlying byte stream as ``.buffer``; on Python 2 the stream itself
    already handles byte strings.
    """
    return getattr(stream, 'buffer', stream)


def filename_sanity_check(fn):
    """Abort via croak() unless fn is a safe single path component.

    Only ASCII letters, digits, '.', '_' and '-' are accepted.  The empty
    string and the special directory entries '.' and '..' are rejected as
    well, since they consist solely of permitted characters but do not name
    a file.
    """
    if fn in ("", ".", ".."):
        croak("Invalid characters encountered in '%s'." % (fn))
    if re.search(r"[^a-zA-Z0-9._-]", fn):
        croak("Invalid characters encountered in '%s'." % (fn))


def get_classdir(file_class):
    """Return the directory for file_class below basedir.

    Aborts via croak() if that directory does not exist.
    """
    d = os.path.join(basedir, file_class)
    if not os.path.exists(d):
        croak("Classdir '%s' does not exist." % (d))
    return d


def get_targetdir(classdir, host, create=False):
    """Return the per-host directory below classdir.

    If the directory is missing it is created when create is true;
    otherwise the call aborts via croak().
    """
    d = os.path.join(classdir, host)
    if not os.path.exists(d):
        if create:
            info("Creating %s" % (d))
            os.mkdir(d)
        else:
            croak("Targetdir '%s' does not exist." % (d))
    return d


def sha512_for_file(fn):
    """Return the hexadecimal SHA-512 digest of the file named fn.

    The file is read in binary mode, block_size bytes at a time, and is
    closed even if reading fails.
    """
    d = hashlib.sha512()
    with open(fn, 'rb') as f:
        while True:
            data = f.read(block_size)
            if not data:
                break
            d.update(data)
    return d.hexdigest()


def store_file(host, remote_args):
    """Receive a file on stdin and store it in the directory of host.

    remote_args must be [fileclass, filename, size, checksum] where size is
    the expected byte count and checksum the expected lowercase hex SHA-512.
    If the target already exists with identical size and checksum the
    process exits with status 0; any validation failure or mismatch aborts
    via croak() (exit status 1).
    """
    if len(remote_args) != 4:
        croak("Exactly four arguments expected for store-file.")
    (fileclass, filename, size, checksum) = remote_args

    # check fileclass
    if fileclass not in accepted_fileclasses:
        croak("Invalid file class '%s'" % (fileclass))

    # check filename
    filename_sanity_check(filename)

    # check and convert size
    try:
        size = int(size)
    except ValueError:
        croak("Invalid size argument '%s'" % (size))
    if size < 0:
        # A negative size can never match; reject it before reading anything.
        croak("Invalid size argument '%s'" % (size))

    # check checksum
    if not re.match(r"^[a-f0-9]{128}$", checksum):
        croak("Invalid checksum argument '%s'." % (checksum))

    classdir = get_classdir(fileclass)
    targetdir = get_targetdir(classdir, host, True)
    target = os.path.join(targetdir, filename)

    if os.path.exists(target):
        checksum_on_disk = sha512_for_file(target)
        size_on_disk = os.stat(target)[stat.ST_SIZE]
        if size_on_disk == size and checksum_on_disk == checksum:
            info("Target '%s' already exists, with same size and checksum (%d, %s)."
                 % (target, size, checksum))
            sys.exit(0)
        else:
            croak("Target '%s' already exists and has different size or checksum (%d vs %d; %s vs %s)."
                  % (target, size_on_disk, size, checksum_on_disk, checksum))

    # The upload goes to a temporary file inside classdir and is hard-linked
    # to its final name only after size and checksum have been verified.
    # Closing the NamedTemporaryFile removes the temporary name again.
    tmp = tempfile.NamedTemporaryFile(dir=classdir, suffix=".%s.%s" % (host, filename))
    try:
        info("Receiving remote %s from %s to %stmp (%s bytes)"
             % (filename, host, tmp.name, size))
        source = _binary_stream(sys.stdin)
        running_size = 0
        digest = hashlib.sha512()
        while True:
            buf = source.read(block_size)
            if not buf:
                break
            digest.update(buf)
            tmp.write(buf)
            running_size += len(buf)
            # Stop as soon as the sender exceeds the announced size.
            if running_size > size:
                croak("Size mismatch")
        tmp.flush()

        file_size = os.stat(tmp.name)[stat.ST_SIZE]
        if file_size != size:
            croak("Size mismatch")
        if file_size != running_size:
            croak("Size mismatch.  WTF.")
        if checksum != digest.hexdigest():
            croak("Checksum mismatch.  WTF.")

        # os.link fails if target appeared in the meantime, so an existing
        # file is never overwritten.
        try:
            os.link(tmp.name, target)
        except OSError as e:
            croak("Failed at linking to target: %s" % (e))
    finally:
        tmp.close()
    info("Successfully stored %s" % (target))


def retrieve_file(host, remote_args, allowed_reads):
    """Send a stored file, preceded by a small text header, to stdout.

    remote_args must be [fileclass, from_host, filename].  The directory
    containing the requested file must be listed verbatim in allowed_reads,
    otherwise the call aborts via croak().  A missing file produces a
    "404 not found" header and exit status 1; on success the header carries
    the size and SHA-512 of the file, followed by an empty line and the raw
    file contents.
    """
    if len(remote_args) != 3:
        croak("Exactly three arguments expected for retrieve-file.")
    (fileclass, from_host, filename) = remote_args

    # check fileclass
    if fileclass not in accepted_fileclasses:
        croak("Invalid file class '%s'" % (fileclass))

    # check filename
    filename_sanity_check(filename)
    # and host
    filename_sanity_check(from_host)

    classdir = get_classdir(fileclass)
    sourcedir = get_targetdir(classdir, from_host)
    source = os.path.join(sourcedir, filename)
    abssource = os.path.abspath(source)
    dirname = os.path.dirname(abssource)

    if dirname not in allowed_reads:
        croak("Host '%s' is not allowed to read from %s" % (host, dirname))

    # Header and payload both go through the byte stream so that they
    # cannot be reordered by separate text/binary buffers.
    out = _binary_stream(sys.stdout)

    if not os.path.exists(abssource):
        out.write("Format: 1\nStatus: 404 not found\n".encode('ascii'))
        out.flush()
        info("Not sending %s to remote %s - file does not exist." % (abssource, host))
        sys.exit(1)

    file_size = os.stat(abssource)[stat.ST_SIZE]
    sha512 = sha512_for_file(abssource)
    info("Sending %s to remote %s (%s bytes)" % (abssource, host, file_size))

    header = "Format: 1\nStatus: 200 OK\nSize: %d\nSHA-512: %s\n\n" % (file_size, sha512)
    out.write(header.encode('ascii'))

    with open(abssource, 'rb') as f:
        while True:
            data = f.read(block_size)
            if not data:
                break
            out.write(data)
    out.flush()


def ensure_args_not_empty(remote_args):
    """Abort via croak() if no remote arguments are left to consume."""
    if len(remote_args) == 0:
        croak("One more argument expected.")


def main():
    """Parse the local command line and dispatch the remote ssh command.

    The single positional argument is the host this key belongs to.  The
    remote command is taken from the SSH_ORIGINAL_COMMAND environment
    variable; its first word must repeat that host name and its second word
    selects the action (store-file or retrieve-file).
    """
    syslog.openlog(sys.argv[0], syslog.LOG_PID, syslog.LOG_DAEMON)

    parser = optparse.OptionParser()
    parser.set_usage("%prog [<options>] <host> (local usage)\n" +
                     "via ssh orig command: <host> store-file <file-class> <file-name> <size> <sha512>\n" +
                     "                      <host> retrieve-file <file-class> <from-host> <file-name>")
    parser.add_option("-r", "--read-allow", dest="allowed_reads", metavar="DIR",
                      action="append",
                      help="Allow host to read files in directory.")
    (options, args) = parser.parse_args()

    if len(args) != 1:
        parser.print_help()
        sys.exit(1)
    host = args.pop(0)

    if 'SSH_ORIGINAL_COMMAND' not in os.environ:
        sys.stderr.write("Did not find SSH_ORIGINAL_COMMAND in environment.\n")
        sys.exit(1)

    remote_args = os.environ['SSH_ORIGINAL_COMMAND'].split()

    ensure_args_not_empty(remote_args)
    remote_supplied_hostname = remote_args.pop(0)
    if remote_supplied_hostname != host:
        croak("Hostname passed from remote does not match locally supplied hostname.")

    ensure_args_not_empty(remote_args)
    action = remote_args.pop(0)
    info("Host %s called with action %s." % (host, action))

    if action == "store-file":
        store_file(host, remote_args)
    elif action == "retrieve-file":
        if options.allowed_reads is None:
            croak("No directories from which read is allowed given on cmdline.")
        retrieve_file(host, remote_args, options.allowed_reads)
    else:
        croak("Invalid operation '%s'" % (action))


if __name__ == "__main__":
    main()

# vim:set et:
# vim:set ts=4:
# vim:set shiftwidth=4: