speed up ssh tarball generation
[mirror/userdir-ldap.git] / ud-generate
index c0b6515..548fd4c 100755 (executable)
@@ -55,7 +55,6 @@ if os.getuid() == 0:
 #
 GroupIDMap = {}
 SubGroupMap = {}
-Allowed = None
 CurrentHost = ""
 
 
@@ -155,12 +154,9 @@ def IsRetired(account):
 #   return account['gidNumber'] == 800
 
 # See if this user is in the group list
-def IsInGroup(account):
-  if Allowed is None:
-     return True
-
+def IsInGroup(account, allowed):
   # See if the primary group is in the list
-  if str(account['gidNumber']) in Allowed: return True
+  if str(account['gidNumber']) in allowed: return True
 
   # Check the host based ACL
   if account.is_allowed_by_hostacl(CurrentHost): return True
@@ -171,7 +167,7 @@ def IsInGroup(account):
   supgroups=[]
   addGroups(supgroups, account['supplementaryGid'], account['uid'])
   for g in supgroups:
-     if Allowed.has_key(g):
+     if allowed.has_key(g):
         return True
   return False
 
@@ -206,8 +202,6 @@ def GenPasswd(accounts, File, HomePrefix, PwdMarker):
       userlist = {}
       i = 0
       for a in accounts:
-         if not IsInGroup(a): continue
-
          # Do not let people try to buffer overflow some busted passwd parser.
          if len(a['gecos']) > 100 or len(a['loginShell']) > 50: continue
 
@@ -265,9 +259,6 @@ def GenShadow(accounts, File):
 
       i = 0
       for a in accounts:
-         Pass = '*'
-         if not IsInGroup(a): continue
-
          # If the account is locked, mark it as such in shadow
          # See Debian Bug #308229 for why we set it to 1 instead of 0
          if not a.pw_active():     ShadowExpire = '1'
@@ -303,8 +294,6 @@ def GenShadowSudo(accounts, File, untrusted):
 
       for a in accounts:
          Pass = '*'
-         if not IsInGroup(a): continue
-     
          if 'sudoPassword' in a:
             for entry in a['sudoPassword']:
                Match = re.compile('^('+UUID_FORMAT+') (confirmed:[0-9a-f]{40}|unconfirmed) ([a-z0-9.,*]+) ([^ ]+)$').match(entry)
@@ -371,7 +360,7 @@ def GenSSHGitolite(accounts, File):
 # Generate the shadow list
 def GenSSHShadow(global_dir, accounts):
    # Fetch all the users
-   userfiles = []
+   userkeys = {}
 
    safe_rmtree(os.path.join(global_dir, 'userkeys'))
    safe_makedirs(os.path.join(global_dir, 'userkeys'))
@@ -379,30 +368,13 @@ def GenSSHShadow(global_dir, accounts):
    for a in accounts:
       if not 'sshRSAAuthKey' in a: continue
 
-      F = None
-      try:
-         OldMask = os.umask(0077)
-         File = os.path.join(global_dir, 'userkeys', a['uid'])
-         F = open(File + ".tmp", "w", 0600)
-         os.umask(OldMask)
-
-         for I in a['sshRSAAuthKey']:
-            MultipleLine = "%s" % I
-            MultipleLine = Sanitize(MultipleLine) + "\n"
-            F.write(MultipleLine)
-
-         Done(File, F, None)
-         userfiles.append(os.path.basename(File))
-
-      # Oops, something unspeakable happened.
-      except IOError:
-         Die(File, F, None)
-         # As neither masterFileName nor masterFile are defined at any point
-         # this will raise a NameError.
-         Die(masterFileName, masterFile, None)
-         raise
-
-   return userfiles
+      contents = []
+      for I in a['sshRSAAuthKey']:
+         MultipleLine = "%s" % I
+         MultipleLine = Sanitize(MultipleLine)
+         contents.append(MultipleLine)
+      userkeys[a['uid']] = contents
+   return userkeys
 
 # Generate the webPassword list
 def GenWebPassword(accounts, File):
@@ -425,12 +397,12 @@ def GenWebPassword(accounts, File):
       Die(File, None, F)
       raise
 
-def GenSSHtarballs(global_dir, userlist, SSHFiles, grouprevmap, target):
+def GenSSHtarballs(global_dir, userlist, ssh_userkeys, grouprevmap, target):
    OldMask = os.umask(0077)
    tf = tarfile.open(name=os.path.join(global_dir, 'ssh-keys-%s.tar.gz' % CurrentHost), mode='w:gz')
    os.umask(OldMask)
-   for f in userlist.keys():
-      if f not in SSHFiles:
+   for f in userlist:
+      if f not in ssh_userkeys:
          continue
       # If we're not exporting their primary group, don't export
       # the key and warn
@@ -451,7 +423,19 @@ def GenSSHtarballs(global_dir, userlist, SSHFiles, grouprevmap, target):
          print "User %s is supposed to have their key exported to host %s but their primary group (gid: %d) isn't in LDAP" % (f, CurrentHost, userlist[f])
          continue
 
-      to = tf.gettarinfo(os.path.join(global_dir, 'userkeys', f), f)
+      contents = ssh_userkeys[f]
+      lines = []
+      for line in contents:
+         if line.startswith("allowed_hosts=") and ' ' in line:
+            machines, line = line.split('=', 1)[1].split(' ', 1)
+            if CurrentHost not in machines.split(','):
+               continue # skip this key
+         lines.append(line)
+      if not lines:
+         continue # no keys for this host
+      contents = "\n".join(lines) + "\n"
+
+      to = tarfile.TarInfo(name=f)
       # These will only be used where the username doesn't
       # exist on the target system for some reason; hence,
       # in those cases, the safest thing is for the file to
@@ -467,19 +451,9 @@ def GenSSHtarballs(global_dir, userlist, SSHFiles, grouprevmap, target):
       to.uname = f
       to.gname = grname
       to.mode  = 0400
-
-      contents = file(os.path.join(global_dir, 'userkeys', f)).read()
-      lines = []
-      for line in contents.splitlines():
-         if line.startswith("allowed_hosts=") and ' ' in line:
-            machines, line = line.split('=', 1)[1].split(' ', 1)
-            if CurrentHost not in machines.split(','):
-               continue # skip this key
-         lines.append(line)
-      if not lines:
-         continue # no keys for this host
-      contents = "\n".join(lines) + "\n"
+      to.mtime = int(time.time())
       to.size = len(contents)
+
       tf.addfile(to, StringIO(contents))
 
    tf.close()
@@ -526,7 +500,6 @@ def GenGroup(accounts, File):
       # Sort them into a list of groups having a set of users
       for a in accounts:
          GroupHasPrimaryMembers[ a['gidNumber'] ] = True
-         if not IsInGroup(a): continue
          if not 'supplementaryGid' in a: continue
 
          supgroups=[]
@@ -568,12 +541,10 @@ def CheckForward(accounts):
    for a in accounts:
       if not 'emailForward' in a: continue
 
-
       delete = False
 
-      if not IsInGroup(a): delete = True
       # Do not allow people to try to buffer overflow busted parsers
-      elif len(a['emailForward']) > 200: delete = True
+      if len(a['emailForward']) > 200: delete = True
       # Check the forwarding address
       elif EmailCheck.match(a['emailForward']) is None: delete = True
 
@@ -1124,7 +1095,7 @@ def generate_all(global_dir, ldap_conn):
    GenAllUsers(accounts, global_dir + 'all-accounts.json')
    accounts = filter(lambda a: not a in accounts_disabled, accounts)
 
-   ssh_files = GenSSHShadow(global_dir, accounts)
+   ssh_userkeys = GenSSHShadow(global_dir, accounts)
    GenMarkers(accounts, global_dir + "markers")
    GenSSHKnown(host_attrs, global_dir + "ssh_known_hosts")
    GenHosts(host_attrs, global_dir + "debianhosts")
@@ -1136,9 +1107,9 @@ def generate_all(global_dir, ldap_conn):
    for host in host_attrs:
       if not "hostname" in host[1]:
          continue
-      generate_host(host, global_dir, accounts, ssh_files)
+      generate_host(host, global_dir, accounts, ssh_userkeys)
 
-def generate_host(host, global_dir, accounts, ssh_files):
+def generate_host(host, global_dir, accounts, ssh_userkeys):
    global CurrentHost
 
    CurrentHost = host[1]['hostname'][0]
@@ -1164,10 +1135,8 @@ def generate_host(host, global_dir, accounts, ssh_files):
       for extra in host[1]['exportOptions']:
          ExtraList[extra.upper()] = True
 
-   global Allowed
-   Allowed = GroupList
-   if Allowed == {}:
-      Allowed = None
+   if GroupList != {}:
+      accounts = filter(lambda x: IsInGroup(x, GroupList), accounts)
 
    DoLink(global_dir, OutDir, "debianhosts")
    DoLink(global_dir, OutDir, "ssh_known_hosts")
@@ -1184,7 +1153,7 @@ def generate_host(host, global_dir, accounts, ssh_files):
 
    # Now we know who we're allowing on the machine, export
    # the relevant ssh keys
-   GenSSHtarballs(global_dir, userlist, ssh_files, grouprevmap, os.path.join(OutDir, 'ssh-keys.tar.gz'))
+   GenSSHtarballs(global_dir, userlist, ssh_userkeys, grouprevmap, os.path.join(OutDir, 'ssh-keys.tar.gz'))
 
    if not 'NOPASSWD' in ExtraList:
       GenShadow(accounts, OutDir + "shadow")
@@ -1201,9 +1170,9 @@ def generate_host(host, global_dir, accounts, ssh_files):
    DoLink(global_dir, OutDir, "mail-rhsbl")
    DoLink(global_dir, OutDir, "mail-whitelist")
    DoLink(global_dir, OutDir, "all-accounts.json")
-   GenCDB(filter(lambda x: IsInGroup(x), accounts), OutDir + "user-forward.cdb", 'emailForward')
-   GenCDB(filter(lambda x: IsInGroup(x), accounts), OutDir + "batv-tokens.cdb", 'bATVToken')
-   GenCDB(filter(lambda x: IsInGroup(x), accounts), OutDir + "default-mail-options.cdb", 'mailDefaultOptions')
+   GenCDB(accounts, OutDir + "user-forward.cdb", 'emailForward')
+   GenCDB(accounts, OutDir + "batv-tokens.cdb", 'bATVToken')
+   GenCDB(accounts, OutDir + "default-mail-options.cdb", 'mailDefaultOptions')
 
    # Compatibility.
    DoLink(global_dir, OutDir, "forward-alias")
@@ -1286,66 +1255,72 @@ def getLastBuildTime():
 
 
 
-parser = optparse.OptionParser()
-parser.add_option("-g", "--generatedir", dest="generatedir", metavar="DIR",
-  help="Output directory.")
-parser.add_option("-f", "--force", dest="force", action="store_true",
-  help="Force generation, even if not update to LDAP has happened.")
+def ud_generate():
+   global GenerateDir
+   global GroupIDMap
+   parser = optparse.OptionParser()
+   parser.add_option("-g", "--generatedir", dest="generatedir", metavar="DIR",
+     help="Output directory.")
+   parser.add_option("-f", "--force", dest="force", action="store_true",
+     help="Force generation, even if not update to LDAP has happened.")
 
-(options, args) = parser.parse_args()
-if len(args) > 0:
-   parser.print_help()
-   sys.exit(1)
+   (options, args) = parser.parse_args()
+   if len(args) > 0:
+      parser.print_help()
+      sys.exit(1)
 
 
-l = make_ldap_conn()
+   l = make_ldap_conn()
 
-if options.generatedir is not None:
-   GenerateDir = os.environ['UD_GENERATEDIR']
-elif 'UD_GENERATEDIR' in os.environ:
-   GenerateDir = os.environ['UD_GENERATEDIR']
+   if options.generatedir is not None:
+      GenerateDir = os.environ['UD_GENERATEDIR']
+   elif 'UD_GENERATEDIR' in os.environ:
+      GenerateDir = os.environ['UD_GENERATEDIR']
 
-ldap_last_mod = getLastLDAPChangeTime(l)
-cache_last_mod = getLastBuildTime()
-need_update = ldap_last_mod > cache_last_mod
+   ldap_last_mod = getLastLDAPChangeTime(l)
+   cache_last_mod = getLastBuildTime()
+   need_update = ldap_last_mod > cache_last_mod
 
-if not options.force and not need_update:
-   fd = open(os.path.join(GenerateDir, "last_update.trace"), "w")
-   fd.write("%s\n%s\n" % (ldap_last_mod, int(time.time())))
-   fd.close()
-   sys.exit(0)
+   if not options.force and not need_update:
+      fd = open(os.path.join(GenerateDir, "last_update.trace"), "w")
+      fd.write("%s\n%s\n" % (ldap_last_mod, int(time.time())))
+      fd.close()
+      sys.exit(0)
 
-# Fetch all the groups
-GroupIDMap = {}
-attrs = l.search_s(BaseDn, ldap.SCOPE_ONELEVEL, "gid=*",\
-                  ["gid", "gidNumber", "subGroup"])
-
-# Generate the SubGroupMap and GroupIDMap
-for x in attrs:
-   if x[1].has_key("accountStatus") and x[1]['accountStatus'] == "disabled":
-      continue
-   if x[1].has_key("gidNumber") == 0:
-      continue
-   GroupIDMap[x[1]["gid"][0]] = int(x[1]["gidNumber"][0])
-   if x[1].has_key("subGroup") != 0:
-      SubGroupMap.setdefault(x[1]["gid"][0], []).extend(x[1]["subGroup"])
-
-lock = None
-try:
-   lockf = os.path.join(GenerateDir, 'ud-generate.lock')
-   lock = get_lock( lockf )
-   if lock is None:
-      sys.stderr.write("Could not acquire lock %s.\n"%(lockf))
-      sys.exit(1)
+   # Fetch all the groups
+   GroupIDMap = {}
+   attrs = l.search_s(BaseDn, ldap.SCOPE_ONELEVEL, "gid=*",\
+                     ["gid", "gidNumber", "subGroup"])
 
-   tracefd = open(os.path.join(GenerateDir, "last_update.trace"), "w")
-   generate_all(GenerateDir, l)
-   tracefd.write("%s\n%s\n" % (ldap_last_mod, int(time.time())))
-   tracefd.close()
+   # Generate the SubGroupMap and GroupIDMap
+   for x in attrs:
+      if x[1].has_key("accountStatus") and x[1]['accountStatus'] == "disabled":
+         continue
+      if x[1].has_key("gidNumber") == 0:
+         continue
+      GroupIDMap[x[1]["gid"][0]] = int(x[1]["gidNumber"][0])
+      if x[1].has_key("subGroup") != 0:
+         SubGroupMap.setdefault(x[1]["gid"][0], []).extend(x[1]["subGroup"])
 
-finally:
-   if lock is not None:
-      lock.release()
+   lock = None
+   try:
+      lockf = os.path.join(GenerateDir, 'ud-generate.lock')
+      lock = get_lock( lockf )
+      if lock is None:
+         sys.stderr.write("Could not acquire lock %s.\n"%(lockf))
+         sys.exit(1)
+
+      tracefd = open(os.path.join(GenerateDir, "last_update.trace"), "w")
+      generate_all(GenerateDir, l)
+      tracefd.write("%s\n%s\n" % (ldap_last_mod, int(time.time())))
+      tracefd.close()
+
+   finally:
+      if lock is not None:
+         lock.release()
+
+if __name__ == "__main__":
+   ud_generate()
 
 
 # vim:set et: