package com.imcode.net.ldap;

import javax.naming.AuthenticationException;
import javax.naming.Context;
import javax.naming.NameNotFoundException;
import javax.naming.NamingException;
import javax.naming.ldap.InitialLdapContext;
import javax.naming.ldap.LdapContext;
import java.util.Date;
import java.util.Hashtable;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;

/**
 * A bounded pool of LDAP connections.
 * <p>
 * All pool state is {@code static}: the first constructed instance initializes the pool, and the
 * arguments passed to any later constructor call are ignored. At most {@code connectionsNumber}
 * connections can be checked out at the same time; further callers of {@link #getConnection()}
 * block (fairly, in arrival order) until a connection is returned by calling its {@code close()}.
 */
public class LdapConnectionPool {

    private static final AtomicBoolean init = new AtomicBoolean(false);
    private static final AtomicLong expiresInMillis = new AtomicLong(0);
    // volatile: these are written once in init() under its lock but read without it in getConnection()
    private static volatile BlockingQueue<LdapConnectionImpl> connectionPool;
    private static volatile Semaphore connectionLimiter;
    private static volatile Hashtable<String, String> env;

    /**
     * Creates a handle to the shared pool, initializing the pool if this is the first instance.
     *
     * @param ldapUrl           LDAP provider URL ({@code Context.PROVIDER_URL})
     * @param ldapBindDn        principal used for simple bind
     * @param ldapPassword      credentials used for simple bind
     * @param readTimeout       value passed as the {@code com.sun.jndi.ldap.read.timeout} JNDI property
     * @param connectionsNumber maximum number of connections checked out at the same time
     * @param expiry            lifetime of a connection, expressed in {@code expiryTimeUnit}
     * @param expiryTimeUnit    unit of {@code expiry}
     */
    public LdapConnectionPool(final String ldapUrl,
                              final String ldapBindDn,
                              final String ldapPassword,
                              final String readTimeout,
                              final int connectionsNumber,
                              final int expiry, TimeUnit expiryTimeUnit) {
        // Lock-free fast path; init() re-checks under its lock so the pool is initialized only once.
        if (!init.get()) {
            init(ldapUrl, ldapBindDn, ldapPassword, readTimeout, connectionsNumber, expiry, expiryTimeUnit);
        }
    }

    /**
     * Checks out a connection, blocking until one of the pool's permits is available.
     * <p>
     * An idle pooled connection is reused unless it has expired, in which case its LDAP context is
     * closed and a new connection is created. The returned connection must be closed exactly once;
     * closing it returns it to the pool and frees its permit.
     *
     * @return a checked-out connection
     * @throws LdapClientException if the wait is interrupted or a new LDAP context cannot be created
     */
    public LdapConnection getConnection() throws LdapClientException {
        try {
            connectionLimiter.acquire();
        } catch (InterruptedException e) {
            // Preserve the interrupt status for callers further up the stack.
            Thread.currentThread().interrupt();
            throw new LdapClientException("Interrupted while waiting for an LDAP connection", e);
        }

        boolean checkedOut = false;
        try {
            LdapConnectionImpl connection = connectionPool.poll();
            if (connection != null && connection.isExpired()) {
                connection.closeLdapContext();
                connection = null;
            }
            if (connection == null) {
                connection = newPooledConnection();
            }
            checkedOut = true;
            return connection;
        } finally {
            // If anything above failed, give the permit back; otherwise the pool would shrink forever.
            if (!checkedOut) {
                connectionLimiter.release();
            }
        }
    }

    /**
     * Creates a new connection whose {@code close()} returns it to the pool instead of closing it.
     */
    private LdapConnectionImpl newPooledConnection() throws LdapClientException {
        return new LdapConnectionImpl(createLdapContext(env), calcExpiry()) {
            @Override
            public void close() throws LdapClientException {
                // NOTE(review): close() must be called once per checkout; a second call would
                // release an extra permit and enqueue this connection twice.
                try {
                    // The queue is bounded by the permit count, so offer() does not block; if it is
                    // ever rejected, drop the connection rather than leak its LDAP context.
                    if (!connectionPool.offer(this)) {
                        closeLdapContext();
                    }
                } finally {
                    connectionLimiter.release();
                }
            }
        };
    }

    /** Returns the expiry timestamp for a connection created now. */
    private Date calcExpiry() {
        return new Date(System.currentTimeMillis() + expiresInMillis.get());
    }

    /**
     * Initializes the shared pool state. This is a no-op if another caller already did so.
     */
    private static synchronized void init(String ldapUrl,
                                          String ldapBindDn,
                                          String ldapPassword,
                                          String readTimeout,
                                          int connectionsNumber,
                                          int expiry, TimeUnit expiryTimeUnit) {
        if (init.get()) {
            // Another thread won the race; replacing the queue/semaphore now would orphan their permits.
            return;
        }
        connectionPool = new LinkedBlockingQueue<LdapConnectionImpl>(connectionsNumber);
        connectionLimiter = new Semaphore(connectionsNumber, true);
        env = createLdapJndiEnvironment(ldapUrl, ldapBindDn, ldapPassword, readTimeout);
        expiresInMillis.set(expiryTimeUnit.toMillis(expiry));
        // Set last: the volatile write publishes the fields above to threads that observe true.
        init.set(true);
    }

    /**
     * Opens a new LDAP context using the given JNDI environment.
     *
     * @throws LdapAuthenticationException if the bind is rejected
     * @throws LdapClientException         if the root is not found or any other naming error occurs
     */
    static LdapContext createLdapContext(Hashtable<String, String> env) throws LdapClientException {
        try {
            return new InitialLdapContext(env, null);
        } catch (AuthenticationException ex) {
            throw new LdapAuthenticationException("Authentication failed, using login: '" + env.get(Context.SECURITY_PRINCIPAL) + "'", ex);
        } catch (NameNotFoundException ex) {
            throw new LdapClientException("Root not found: " + env.get(Context.PROVIDER_URL), ex);
        } catch (NamingException ex) {
            throw wrapNamingException(env.get(Context.PROVIDER_URL), ex);
        }
    }

    /**
     * Builds the JNDI environment for a simple-bind LDAP connection that follows referrals.
     * <p>
     * All arguments must be non-null, because {@link Hashtable} rejects null values.
     */
    static Hashtable<String, String> createLdapJndiEnvironment(String ldapUrl, String ldapBindDn, String ldapPassword, String readTimeout) {
        Hashtable<String, String> env = new Hashtable<String, String>();

        env.put(Context.INITIAL_CONTEXT_FACTORY, "com.sun.jndi.ldap.LdapCtxFactory");
        env.put("com.sun.jndi.ldap.read.timeout", readTimeout);
        env.put(Context.PROVIDER_URL, ldapUrl);
        env.put(Context.SECURITY_AUTHENTICATION, LdapConnectionImpl.AUTHENTICATION_TYPE_SIMPLE);
        env.put(Context.SECURITY_PRINCIPAL, ldapBindDn);
        env.put(Context.SECURITY_CREDENTIALS, ldapPassword);
        env.put(Context.REFERRAL, "follow");
        // Active directory only: attributes to be returned in binary format
        env.put("java.naming.ldap.attributes.binary", "tokenGroups");

        return env;
    }

    /** Wraps a generic naming failure with the provider URL and the JNDI explanation. */
    private static LdapClientException wrapNamingException(String ldapUrl, NamingException ex) {
        return new LdapClientException("Failed to create LDAP context " + ldapUrl + ": " + ex.getExplanation(), ex);
    }
}
