package com.imcode.net.ldap;

import org.apache.commons.collections.map.CaseInsensitiveMap;
import org.apache.commons.lang.ArrayUtils;
import org.apache.commons.lang.UnhandledException;
import org.apache.log4j.Logger;

import javax.naming.*;
import javax.naming.directory.*;
import javax.naming.ldap.LdapContext;
import javax.naming.ldap.InitialLdapContext;
import javax.naming.ldap.LdapName;
import java.util.*;

public class LdapConnection {

    /** Logger used for retry notices and failures while closing contexts. */
    private static final Logger LOG = Logger.getLogger(LdapConnection.class);

    /** Value stored under {@code Context.SECURITY_AUTHENTICATION} in the JNDI environment. */
    private static final String AUTHENTICATION_TYPE_SIMPLE = "simple";
    /** Map key under which single-valued search results expose the entry's distinguished name. */
    public static final String DISTINGUISHED_NAME = "dn";

    /** Provider URL handed to the JNDI environment on every (re)connect. */
    private final String ldapUrl;
    /** Principal (bind DN) handed to the JNDI environment on every (re)connect. */
    private final String ldapBindDn;
    /** Credentials handed to the JNDI environment on every (re)connect. */
    private final String ldapPassword;

    /** Context all searches run against; reassigned by {@code connect()}. */
    private volatile LdapContext ctx;

    /**
     * Stores the connection settings and immediately opens the LDAP context.
     *
     * @param ldapUrl      provider URL of the LDAP server
     * @param ldapBindDn   principal used for the simple bind
     * @param ldapPassword credentials used for the simple bind
     * @throws LdapClientException if the initial {@code connect()} fails
     */
    public LdapConnection(String ldapUrl, String ldapBindDn, String ldapPassword) throws LdapClientException {
        this.ldapUrl = ldapUrl;
        this.ldapBindDn = ldapBindDn;
        this.ldapPassword = ldapPassword;
        // Connecting here means a constructed instance always has a context assigned.
        connect();
    }

    /**
     * Opens a fresh LDAP context with the stored URL, bind DN and password and
     * publishes it in {@code ctx}. The context that was published before (if any)
     * is closed afterwards so that reconnecting does not leak contexts.
     *
     * @throws LdapAuthenticationException if the server rejects the bind
     * @throws LdapClientException         if the root is not found or any other naming error occurs
     */
    private synchronized void connect() throws LdapClientException {
        final LdapContext freshContext;
        try {
            freshContext = new InitialLdapContext(createLdapJndiEnvironment(ldapUrl, ldapBindDn, ldapPassword), null);
        } catch (AuthenticationException ex) {
            throw new LdapAuthenticationException("Authentication failed, using login: '" + ldapBindDn + "'", ex);
        } catch (NameNotFoundException ex) {
            throw new LdapClientException("Root not found: " + ldapUrl, ex);
        } catch (NamingException ex) {
            throw wrapNamingException(ldapUrl, ex);
        }

        // Publish the new context first so other callers never observe a closed one.
        final LdapContext staleContext = ctx;
        ctx = freshContext;

        if (staleContext != null) {
            try {
                staleContext.close();
            } catch (NamingException ex) {
                // Best effort: the stale context is usually already broken when we reconnect.
                LOG.debug("Closing stale context failed.", ex);
            }
        }
    }

    /**
     * Builds the JNDI environment used to create the initial LDAP context.
     *
     * @param ldapUrl      value for {@code Context.PROVIDER_URL}
     * @param ldapBindDn   value for {@code Context.SECURITY_PRINCIPAL}
     * @param ldapPassword value for {@code Context.SECURITY_CREDENTIALS}
     * @return a new, caller-owned environment table
     */
    private static Hashtable<String, Object> createLdapJndiEnvironment(String ldapUrl, String ldapBindDn, String ldapPassword) {
        Hashtable<String, Object> env = new Hashtable<String, Object>();

        env.put(Context.INITIAL_CONTEXT_FACTORY, "com.sun.jndi.ldap.LdapCtxFactory");
        // Read timeout value handed to the Sun LDAP provider.
        env.put("com.sun.jndi.ldap.read.timeout", "5000");
        env.put(Context.PROVIDER_URL, ldapUrl);
        env.put(Context.SECURITY_AUTHENTICATION, AUTHENTICATION_TYPE_SIMPLE);
        env.put(Context.SECURITY_PRINCIPAL, ldapBindDn);
        env.put(Context.SECURITY_CREDENTIALS, ldapPassword);
        env.put(Context.REFERRAL, "follow");
        // Active directory only: attributes to be returned in binary format
        env.put("java.naming.ldap.attributes.binary", "tokenGroups");

        return env;
    }


    /**
     * Runs a single search against the current context, starting at its root name,
     * and wraps the result enumeration in a single-valued iterator. No retry is
     * attempted here.
     *
     * @param searchFilterExpr filter expression, may contain {@code {n}} placeholders
     * @param parameters       values for the placeholders
     * @param searchControls   controls to use, or {@code null} for default controls;
     *                         a non-null instance is modified (returning-object flag is set)
     * @return iterator over the matching entries
     * @throws NamingException if the search itself fails
     */
    private Iterator<Map<String, String>> trySearch(String searchFilterExpr, Object[] parameters, SearchControls searchControls) throws NamingException {
        final SearchControls effectiveControls = (searchControls != null) ? searchControls : new SearchControls();
        // The iterator reads the distinguished name from the object bound to each result.
        effectiveControls.setReturningObjFlag(true);
        return new SearchResultIterator(
                ctx.search("", searchFilterExpr, parameters, effectiveControls),
                effectiveControls);
    }

    /**
     * Runs a single search against the current context, starting at its root name,
     * and wraps the result enumeration in a multi-valued iterator. No retry is
     * attempted here.
     *
     * @param searchFilterExpr filter expression, may contain {@code {n}} placeholders
     * @param parameters       values for the placeholders
     * @param searchControls   controls to use, or {@code null} for default controls;
     *                         a non-null instance is modified (returning-object flag is set)
     * @return iterator over the matching entries
     * @throws NamingException if the search itself fails
     */
    private Iterator<Map<String, Set<String>>> trySearchMultivalues(String searchFilterExpr, Object[] parameters, SearchControls searchControls) throws NamingException {
        final SearchControls effectiveControls = (searchControls != null) ? searchControls : new SearchControls();
        // Request that the bound object is returned together with each result.
        effectiveControls.setReturningObjFlag(true);
        return new SearchMultivalueResultIterator(
                ctx.search("", searchFilterExpr, parameters, effectiveControls),
                effectiveControls);
    }

    /**
     * Searches from the context root and returns each entry as a map of
     * attribute name to a single value.
     * <p>
     * A {@link CommunicationException} triggers one reconnect and one retry. If the
     * failure is a read timeout, the search is retried up to five more times, each
     * time on a fresh connection. Any other failure is reported immediately.
     *
     * @param searchFilterExpr filter expression, may contain {@code {n}} placeholders
     * @param parameters       values for the placeholders
     * @param searchControls   controls to use, or {@code null} for default controls
     * @return iterator over the matching entries
     * @throws LdapClientException if the search or a reconnect fails; the cause is the
     *                             exception that actually ended the attempt
     */
    public Iterator<Map<String, String>> search(String searchFilterExpr, Object[] parameters,
                                                SearchControls searchControls) throws LdapClientException {
        final int maxTimeoutRetries = 5;
        try {
            try {
                return trySearch(searchFilterExpr, parameters, searchControls);
            } catch (CommunicationException ce) {
                LOG.warn("Problem communicating with LDAP server, retrying.", ce);
                connect();
                return trySearch(searchFilterExpr, parameters, searchControls);
            }
        } catch (NamingException ne) {
            NamingException lastFailure = ne;
            if (isTimedOutNamingException(ne)) {
                for (int i = 1; i <= maxTimeoutRetries; i++) {
                    LOG.info("LDAP response read timed out, retrying. Attempt " + i);
                    try {
                        connect();
                        return trySearch(searchFilterExpr, parameters, searchControls);
                    } catch (NamingException namingException) {
                        if (!isTimedOutNamingException(namingException)) {
                            // Report the failure that actually stopped the retries,
                            // not the initial timeout.
                            throw new LdapClientException("LDAP search failed.", namingException);
                        }
                        lastFailure = namingException;
                    }
                }
            }
            throw new LdapClientException("LDAP search failed.", lastFailure);
        }
    }

    /**
     * Tells whether the exception message marks an LDAP read timeout.
     *
     * @param ne exception to inspect
     * @return {@code true} if the message starts with the read-timeout prefix;
     *         {@code false} otherwise, including when the exception has no message
     */
    private boolean isTimedOutNamingException(NamingException ne) {
        // NamingException may legitimately carry no message at all.
        final String message = ne.getMessage();
        return message != null && message.startsWith("LDAP response read timed out, timeout used:");
    }

    /**
     * Searches from the context root and returns each entry as a map of
     * attribute name to the set of all its values.
     * <p>
     * A {@link CommunicationException} on the first attempt triggers one reconnect
     * and one retry; every other naming failure is reported immediately.
     *
     * @param searchFilterExpr filter expression, may contain {@code {n}} placeholders
     * @param parameters       values for the placeholders
     * @param searchControls   controls to use, or {@code null} for default controls
     * @return iterator over the matching entries
     * @throws LdapClientException if the search or the reconnect fails
     */
    public Iterator<Map<String, Set<String>>> searchMultivalues(String searchFilterExpr, Object[] parameters,
                                                                SearchControls searchControls) throws LdapClientException {
        try {
            return trySearchMultivalues(searchFilterExpr, parameters, searchControls);
        } catch (CommunicationException firstFailure) {
            LOG.warn("Problem communicating with LDAP server, retrying.", firstFailure);
            try {
                connect();
                return trySearchMultivalues(searchFilterExpr, parameters, searchControls);
            } catch (NamingException retryFailure) {
                throw new LdapClientException("LDAP search failed.", retryFailure);
            }
        } catch (NamingException failure) {
            throw new LdapClientException("LDAP search failed.", failure);
        }
    }

    /**
     * Closes the current LDAP context. Failures while closing are logged at debug
     * level and otherwise ignored. Does nothing if no context was ever opened,
     * which can happen when the finalizer runs after a failed constructor.
     */
    public void close() {
        final LdapContext current = ctx;
        if (current == null) {
            return;
        }
        try {
            current.close();
        } catch (NamingException ne) {
            LOG.debug("Closing context failed.", ne);
        }
    }

    /**
     * Safety net that closes the context when the instance is garbage collected.
     * Callers should still invoke {@link #close()} explicitly.
     */
    @Override
    protected void finalize() throws Throwable {
        try {
            close();
        } finally {
            // Always give the superclass its turn, even if close() fails unexpectedly.
            super.finalize();
        }
    }

    /**
     * Converts a naming failure raised while creating the context into an
     * {@code LdapClientException} whose message names the URL and the
     * exception's explanation.
     *
     * @param ldapUrl URL the context was being created for
     * @param ex      the naming failure; kept as the cause
     * @return the wrapping exception (not thrown here)
     */
    private static LdapClientException wrapNamingException(String ldapUrl, NamingException ex) {
        final StringBuilder message = new StringBuilder("Failed to create LDAP context ");
        message.append(ldapUrl).append(": ").append(ex.getExplanation());
        return new LdapClientException(message.toString(), ex);
    }

    /**
     * Iterator view over a search result enumeration that turns every entry into a
     * case-insensitive map of attribute name to one value of that attribute.
     * When requested by the search controls, the entry's distinguished name is
     * added under {@link #DISTINGUISHED_NAME}.
     */
    private static class SearchResultIterator implements Iterator<Map<String, String>> {

        private final NamingEnumeration<SearchResult> enumeration;
        private final SearchControls searchControls;

        SearchResultIterator(NamingEnumeration<SearchResult> enumeration, SearchControls searchControls) {
            this.enumeration = enumeration;
            this.searchControls = searchControls;
        }

        @Override
        public boolean hasNext() {
            return enumeration.hasMoreElements();
        }

        /**
         * @return the next entry as an attribute map
         * @throws UnhandledException wrapping any {@link NamingException} raised while
         *                            reading the entry
         */
        @Override
        public Map<String, String> next() {
            try {
                return createMapFromSearchResult(enumeration.nextElement());
            } catch (NamingException e) {
                throw new UnhandledException(e);
            }
        }

        /** Removal is not supported. */
        @Override
        public void remove() {
            throw new UnsupportedOperationException();
        }

        /**
         * Copies one value per attribute into a case-insensitive map and, if the
         * entry did not deliver a "dn" attribute itself, adds the distinguished name
         * when the controls ask for returned objects and do not exclude "dn".
         */
        @SuppressWarnings("unchecked") // CaseInsensitiveMap is not generic
        private Map<String, String> createMapFromSearchResult(SearchResult searchResult) throws NamingException {

            NamingEnumeration<? extends Attribute> attribEnum = searchResult.getAttributes().getAll();
            Map<String, String> attributes = new CaseInsensitiveMap();
            while (attribEnum.hasMoreElements()) {
                Attribute attribute = attribEnum.nextElement();
                String attributeName = attribute.getID();
                String attributeValue = attribute.get().toString();
                attributes.put(attributeName, attributeValue);
            }
            if (!attributes.containsKey(DISTINGUISHED_NAME)) {
                boolean includeDistinguishedName = null != searchControls && searchControls.getReturningObjFlag() && (null == searchControls.getReturningAttributes()
                        || ArrayUtils.contains(searchControls.getReturningAttributes(), DISTINGUISHED_NAME));
                if (includeDistinguishedName) {
                    attributes.put(DISTINGUISHED_NAME, resolveDistinguishedName(searchResult));
                }
            }

            return attributes;
        }

        /**
         * Reads the full DN from the context bound to the result and closes that
         * context afterwards, because the caller owns every context a search hands
         * out. Falls back to the result's own name when no object was returned.
         */
        private static String resolveDistinguishedName(SearchResult searchResult) throws NamingException {
            Object boundObject = searchResult.getObject();
            if (boundObject == null) {
                return searchResult.getNameInNamespace();
            }
            DirContext dirContext = (DirContext) boundObject;
            try {
                return dirContext.getNameInNamespace();
            } finally {
                try {
                    dirContext.close();
                } catch (NamingException ne) {
                    // Best effort: the name has already been read.
                    LOG.debug("Closing context failed.", ne);
                }
            }
        }

    }


    /**
     * Iterator view over a search result enumeration that turns every entry into a
     * case-insensitive map of attribute name to an unmodifiable set holding all
     * values of that attribute.
     */
    private static class SearchMultivalueResultIterator implements Iterator<Map<String, Set<String>>> {

        private final NamingEnumeration<SearchResult> enumeration;
        private final SearchControls searchControls;

        SearchMultivalueResultIterator(NamingEnumeration<SearchResult> enumeration, SearchControls searchControls) {
            this.searchControls = searchControls;
            this.enumeration = enumeration;
        }

        public boolean hasNext() {
            return enumeration.hasMoreElements();
        }

        /**
         * @return the next entry as a map of attribute name to value set
         * @throws UnhandledException wrapping any {@link NamingException} raised while
         *                            reading the entry
         */
        public Map<String, Set<String>> next() {
            final SearchResult entry = enumeration.nextElement();
            try {
                return createMapFromSearchResult(entry);
            } catch (NamingException e) {
                throw new UnhandledException(e);
            }
        }

        /** Removal is not supported. */
        public void remove() {
            throw new UnsupportedOperationException();
        }

        /** Collects every value of every attribute of the entry. */
        private Map<String, Set<String>> createMapFromSearchResult(SearchResult searchResult) throws NamingException {
            Map<String, Set<String>> valuesByAttribute = new CaseInsensitiveMap();

            for (NamingEnumeration allAttributes = searchResult.getAttributes().getAll(); allAttributes.hasMoreElements(); ) {
                Attribute current = (Attribute) allAttributes.nextElement();
                Set<String> collected = new HashSet<String>();

                // Values are taken as delivered by the provider; no conversion is applied.
                for (Enumeration<String> values = (Enumeration<String>) current.getAll(); values.hasMoreElements(); ) {
                    collected.add(values.nextElement());
                }

                valuesByAttribute.put(current.getID(), Collections.unmodifiableSet(collected));
            }

            return valuesByAttribute;
        }

    }

    /**
     * Looks up the distinguished name of an Active Directory user. The lookup is a
     * subtree search for {@code (&(objectClass=user)(sAMAccountName=...))} and is
     * retried once on a fresh connection after a communication failure.
     *
     * @param sAMAccountName account name to look for; passed as a filter argument
     * @return Active directory user's DN or null if there is no user with provided sAMAccountName.
     * @throws LdapClientException if the search or the reconnect fails
     */
    public String getADUserDn(final String sAMAccountName) throws LdapClientException {
        return executeWitReconnect(new LdapCommand<String>() {
            @Override
            public String execute() throws NamingException {
                String searchFilter = "(&(objectClass=user)(sAMAccountName={0}))";
                SearchControls searchCtls = new SearchControls();
                // No attributes are needed, only the entry's name.
                searchCtls.setReturningAttributes(new String[]{});
                searchCtls.setSearchScope(SearchControls.SUBTREE_SCOPE);

                NamingEnumeration<SearchResult> result = ctx.search("", searchFilter, new Object[]{sAMAccountName}, searchCtls);
                try {
                    return result.hasMoreElements() ? result.nextElement().getNameInNamespace() : null;
                } finally {
                    // Only the first match is read, so the enumeration must be
                    // released explicitly.
                    try {
                        result.close();
                    } catch (NamingException ne) {
                        LOG.debug("Closing search result failed.", ne);
                    }
                }
            }
        });
    }


    /**
     * Resolves the account names of all groups an Active Directory user belongs to.
     * The user's {@code tokenGroups} attribute is read as binary SIDs, then every
     * object whose {@code objectSid} matches one of them is looked up and its
     * {@code sAMAccountName} collected.
     *
     * @param sAMAccountName account name of the user
     * @return the group account names; empty if the user does not exist, has no
     *         token groups, or no group entry carries a {@code sAMAccountName}
     * @throws LdapClientException if any of the searches fails
     */
    public Set<String> getADUserGroups(final String sAMAccountName) throws LdapClientException {
        String userDn = getADUserDn(sAMAccountName);

        if (userDn == null) {
            return Collections.emptySet();
        }

        Set<String> groups = new HashSet<String>();
        SearchControls userSearchCtls = new SearchControls();

        userSearchCtls.setSearchScope(SearchControls.OBJECT_SCOPE);
        userSearchCtls.setReturningAttributes(new String[]{"tokenGroups"});

        try {
            // The search name must be relative to the context, so strip the
            // context's own DN components from the user's full DN.
            Name userDnName = new LdapName(userDn);
            Name ctxName = new LdapName(ctx.getNameInNamespace());
            Name name = userDnName.getSuffix(ctxName.size());

            NamingEnumeration userAnswer = ctx.search(name, "(objectClass=user)", userSearchCtls);
            StringBuilder groupsSearchFilter = new StringBuilder("(|");
            boolean anySidFound = false;

            while (userAnswer.hasMoreElements()) {
                SearchResult sr = (SearchResult) userAnswer.nextElement();
                Attributes attrs = sr.getAttributes();

                if (attrs != null) {
                    for (NamingEnumeration ae = attrs.getAll(); ae.hasMoreElements(); ) {
                        Attribute attr = (Attribute) ae.nextElement();
                        for (NamingEnumeration e = attr.getAll(); e.hasMoreElements(); ) {
                            byte[] sid = (byte[]) e.nextElement();
                            groupsSearchFilter.append("(objectSid=").append(binarySidToStringSid(sid)).append(')');
                            anySidFound = true;
                        }
                    }
                }
            }

            if (!anySidFound) {
                // An OR filter without operands, "(|)", is not a valid LDAP filter.
                return groups;
            }

            groupsSearchFilter.append(")");

            SearchControls groupsSearchCtls = new SearchControls();
            groupsSearchCtls.setSearchScope(SearchControls.SUBTREE_SCOPE);
            groupsSearchCtls.setReturningAttributes(new String[]{"sAMAccountName"});

            NamingEnumeration groupsAnswer = ctx.search("", groupsSearchFilter.toString(), groupsSearchCtls);

            while (groupsAnswer.hasMoreElements()) {
                SearchResult sr = (SearchResult) groupsAnswer.nextElement();
                Attributes attrs = sr.getAttributes();
                Attribute groupName = (attrs == null) ? null : attrs.get("sAMAccountName");

                // Skip entries that matched by SID but carry no account name.
                if (groupName != null) {
                    groups.add(groupName.get().toString());
                }
            }
        } catch (NamingException e) {
            throw new LdapClientException("Failed to resolve groups of user '" + sAMAccountName + "'", e);
        }

        return groups;
    }


    /**
     * Converts a binary security identifier into its string form
     * {@code S-<revision>-<authority>-<sub authority>...}.
     * <p>
     * Layout read from the array: byte 0 is the revision, byte 1 the number of sub
     * authorities, bytes 2..7 the identifier authority (big-endian), followed by
     * one 4-byte little-endian value per sub authority. All numbers are rendered
     * in decimal.
     *
     * @param SID binary SID
     * @return the string SID
     * @throws IllegalArgumentException if the array is {@code null} or shorter than
     *                                  its header and declared sub authorities require
     */
    public static String binarySidToStringSid(byte[] SID) {
        final int headerLength = 8;
        if (SID == null || SID.length < headerLength) {
            throw new IllegalArgumentException("Binary SID must have at least " + headerLength + " bytes");
        }

        final int subAuthorityCount = SID[1] & 0xFF;
        if (SID.length < headerLength + 4 * subAuthorityCount) {
            throw new IllegalArgumentException("Binary SID is too short for " + subAuthorityCount + " sub authorities");
        }

        StringBuilder strSID = new StringBuilder("S-");
        strSID.append(SID[0] & 0xFF);

        // Identifier authority: 48-bit big-endian value in bytes 2..7.
        long authority = 0;
        for (int i = 2; i < headerLength; i++) {
            authority = (authority << 8) | (SID[i] & 0xFF);
        }
        strSID.append('-').append(authority);

        // Sub authorities: unsigned 32-bit little-endian values.
        for (int j = 0; j < subAuthorityCount; j++) {
            int offset = headerLength + j * 4;
            long rid = 0;
            for (int k = 3; k >= 0; k--) {
                rid = (rid << 8) | (SID[offset + k] & 0xFF);
            }
            strSID.append('-').append(rid);
        }

        return strSID.toString();
    }

    /**
     * A unit of LDAP work that can be passed to {@code executeWitReconnect}.
     *
     * @param <T> type of the value the command produces
     */
    private interface LdapCommand<T> {
        /**
         * Performs the work.
         *
         * @return the command's result
         * @throws NamingException if the underlying LDAP operation fails
         */
        T execute() throws NamingException;
    }


    /**
     * Runs the command with the default budget of two attempts.
     *
     * @param ldapCommand work to perform
     * @return the command's result
     * @throws LdapClientException if the command or a reconnect fails
     */
    private <T> T executeWitReconnect(LdapCommand<T> ldapCommand) throws LdapClientException {
        final int defaultAttempts = 2;
        return executeWitReconnect(ldapCommand, defaultAttempts);
    }

    /**
     * Runs the command and, after a {@link CommunicationException}, reconnects and
     * runs it again for as long as more than one attempt remains. Any other
     * {@link NamingException} ends the execution immediately.
     *
     * @param ldapCommand work to perform
     * @param retryCount  total number of attempts allowed
     * @return the command's result
     * @throws LdapClientException wrapping the naming failure, or raised by the reconnect
     */
    private <T> T executeWitReconnect(LdapCommand<T> ldapCommand, int retryCount) throws LdapClientException {
        int attemptsLeft = retryCount;

        while (true) {
            try {
                return ldapCommand.execute();
            } catch (CommunicationException e) {
                if (attemptsLeft <= 1) {
                    throw new LdapClientException("", e);
                }
            } catch (NamingException e) {
                throw new LdapClientException("", e);
            }

            // Only reached after a communication failure with attempts remaining.
            connect();
            attemptsLeft--;
        }
    }
}
