/**
 * $Id: SSHTunnel.java 182695 2022-06-21 19:15:27Z fpina $
 */

package csbase.sshclient;

import java.io.File;
import java.io.IOException;
import java.net.BindException;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.text.MessageFormat;

import net.schmizz.sshj.connection.channel.direct.Parameters;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import net.schmizz.sshj.SSHClient;
import net.schmizz.sshj.connection.channel.direct.LocalPortForwarder;
import net.schmizz.sshj.userauth.keyprovider.OpenSSHKeyFile;

/**
 * SSH tunnel that forwards a local port, through an SSH connection to a
 * reachable host, to a host/port that is only reachable from that host.
 *
 * <p>
 * The constructor connects and authenticates against the reachable host, binds
 * a server socket on {@code localhost} and starts a thread that runs the
 * port forwarder. {@link #close()} closes the local server socket.
 *
 * @author Tecgraf
 */
class SSHTunnel implements AutoCloseable {
  /** Default local hostname */
  private static final String LOCALHOST = "localhost";
  /** Highest valid TCP port number */
  private static final int MAX_PORT = 65535;
  /** Delay, in milliseconds, before trying the next port of the range */
  private static final long BIND_RETRY_DELAY_MILLIS = 1000L;
  /** Logger */
  private final Logger log = LoggerFactory.getLogger(this.getClass());
  /** Local port actually bound by the tunnel */
  private final int localPort;
  /** The tunnel */
  private final SSHClient sshTunnelClient;
  /** The local socket */
  private final ServerSocket tunnelSocket;

  /**
   * Constructor. Connects to the reachable host, authenticates with the given
   * private key, binds the first free local port in
   * {@code [localPort, localPort + localRange]} and starts the forwarding
   * thread.
   *
   * @param reachableHost reachable host
   * @param reachablePort reachable host port
   * @param reachableUserName reachable host username
   * @param reachablePrivateKeyFilePath reachable host private key
   * @param unreachableHost unreachable host
   * @param unreachablePort unreachable host port
   * @param localPort local port
   * @param localRange local port range
   *
   * @throws SSHClientException if any argument is invalid, if the private key
   *         file does not exist, or if connecting, authenticating or binding
   *         the local port fails
   */
  public SSHTunnel(final String reachableHost, final int reachablePort,
    final String reachableUserName, final String reachablePrivateKeyFilePath,
    final String unreachableHost, final int unreachablePort,
    final int localPort, final int localRange) throws SSHClientException {

    if (reachableHost == null || reachableHost.isEmpty()) {
      throw new SSHClientException("Reachable host is null or empty");
    }
    if (reachablePort <= 0) {
      throw new SSHClientException("Invalid reachable port");
    }
    if (reachableUserName == null || reachableUserName.isEmpty()) {
      throw new SSHClientException("User name is null or empty");
    }
    if (reachablePrivateKeyFilePath == null
      || reachablePrivateKeyFilePath.isEmpty()) {
      throw new SSHClientException("Private key path is null or empty");
    }

    final File reachablePrivateKey = new File(reachablePrivateKeyFilePath);
    if (!reachablePrivateKey.exists()) {
      throw new SSHClientException("Private key does not exist");
    }

    if (unreachableHost == null || unreachableHost.isEmpty()) {
      throw new SSHClientException("Unreachable Host is null or empty");
    }
    if (unreachablePort <= 0) {
      throw new SSHClientException("Invalid unreachable port");
    }

    if (localPort <= 0 || localPort > MAX_PORT) {
      throw new SSHClientException("Invalid local port");
    }

    if (localRange < 0) {
      throw new SSHClientException("Invalid local port range");
    }

    final int initLocalPortRange = localPort;
    // Computed in long and clamped so that a large range can neither overflow
    // int nor walk past the last valid TCP port.
    final int endLocalPortRange =
      (int) Math.min((long) MAX_PORT, (long) localPort + localRange);

    this.sshTunnelClient = new SSHClient();
    // NOTE(review): the helper name suggests host key verification is
    // disabled for this client -- confirm against SSHUtils.
    SSHUtils.addBlankHostKeyVerifier(this.sshTunnelClient);
    final OpenSSHKeyFile keyProv = new OpenSSHKeyFile();
    keyProv.init(reachablePrivateKey);

    ServerSocket boundSocket;
    try {
      this.sshTunnelClient.connect(reachableHost, reachablePort);
      this.sshTunnelClient.authPublickey(reachableUserName, keyProv);
      boundSocket =
        createServerSocket(LOCALHOST, initLocalPortRange, endLocalPortRange);
    }
    catch (Exception e) {
      // Do not leave a half-open SSH connection behind when setup fails.
      disconnectQuietly();
      final String errMsg =
        "Error while creating "
          + describeTunnel(reachableHost, reachablePort, reachableUserName,
            reachablePrivateKeyFilePath, unreachableHost, unreachablePort,
            localPort);
      throw new SSHClientException(errMsg, e);
    }
    this.tunnelSocket = boundSocket;
    this.localPort = boundSocket.getLocalPort();

    final Parameters params =
      new Parameters(LOCALHOST, this.localPort, unreachableHost,
        unreachablePort);

    // Describes the tunnel with the port that was actually bound, which may
    // differ from the requested one when a range is given.
    final String tunnelDescription =
      describeTunnel(reachableHost, reachablePort, reachableUserName,
        reachablePrivateKeyFilePath, unreachableHost, unreachablePort,
        this.localPort);

    final Thread tunnelThread = new Thread("ssh-tunnel-" + this.localPort) {
      @Override
      public void run() {
        try {
          log.info("Creating " + tunnelDescription);
          sshTunnelClient.newLocalPortForwarder(params, tunnelSocket).listen();
        }
        catch (java.net.SocketException e) {
          if (tunnelSocket.isClosed()) {
            // Presumably the result of close() shutting the server socket
            // down; treated as a normal stop.
            log.debug("Tunnel socket closed, stopping port forwarding", e);
          }
          else {
            log.error("Error while forwarding port", e);
          }
        }
        catch (IOException e) {
          log.error("Error while forwarding port", e);
        }
        finally {
          disconnectQuietly();
        }
      }
    };

    tunnelThread.start();
  }

  /**
   * Constructor for a tunnel bound to exactly one local port (range of zero).
   *
   * @param reachableHost reachable host
   * @param reachablePort reachable host port
   * @param reachableUserName reachable host username
   * @param reachablePrivateKeyFilePath reachable host private key
   * @param unreachableHost unreachable host
   * @param unreachablePort unreachable host port
   * @param localPort local port
   *
   * @throws SSHClientException if any argument is invalid, if the private key
   *         file does not exist, or if connecting, authenticating or binding
   *         the local port fails
   */
  public SSHTunnel(final String reachableHost, final int reachablePort,
    final String reachableUserName, final String reachablePrivateKeyFilePath,
    final String unreachableHost, final int unreachablePort,
    final int localPort) throws SSHClientException {
    this(reachableHost, reachablePort, reachableUserName,
      reachablePrivateKeyFilePath, unreachableHost, unreachablePort, localPort,
      0);
  }

  /**
   * Builds the textual description of a tunnel used in log and error messages.
   * Ports are converted to plain strings so that {@link MessageFormat} does not
   * apply locale-specific number formatting (e.g. grouping separators) to them.
   *
   * @param reachableHost reachable host
   * @param reachablePort reachable host port
   * @param reachableUserName reachable host username
   * @param reachablePrivateKeyFilePath reachable host private key
   * @param unreachableHost unreachable host
   * @param unreachablePort unreachable host port
   * @param localPort local port
   *
   * @return the description
   */
  private static String describeTunnel(String reachableHost,
    int reachablePort, String reachableUserName,
    String reachablePrivateKeyFilePath, String unreachableHost,
    int unreachablePort, int localPort) {
    return MessageFormat.format(
      "tunnel {0}:{1} -> {4}:{5} at local port {6} [{2} {3}]", reachableHost,
      Integer.toString(reachablePort), reachableUserName,
      reachablePrivateKeyFilePath, unreachableHost,
      Integer.toString(unreachablePort), Integer.toString(localPort));
  }

  /**
   * Creates the server socket, bound to the first available local port in the
   * specified range, that forwards messages to the remote server.
   *
   * @param host local hostname
   * @param initPortRange initial port in the range
   * @param endPortRange end port in the range
   *
   * @return the server socket
   *
   * @throws IOException if there's an error creating the socket
   * @throws SSHClientException if no port in the range is available
   */
  private static ServerSocket createServerSocket(String host,
    int initPortRange, int endPortRange) throws IOException,
    SSHClientException {
    int port = initPortRange;
    while (true) {
      final ServerSocket serverSocket = new ServerSocket();
      try {
        serverSocket.setReuseAddress(true);
        serverSocket.bind(new InetSocketAddress(host, port));
        return serverSocket;
      }
      catch (BindException e) {
        // The socket that failed to bind is useless; release it before trying
        // the next port.
        closeQuietly(serverSocket);
        port++;
        if (port > endPortRange) {
          final String msg;
          if (initPortRange == endPortRange) {
            msg = "Port " + initPortRange + " is in use.";
          }
          else {
            msg =
              "All local ports in the range [" + initPortRange + " - "
                + endPortRange + "] are in use.";
          }
          throw new SSHClientException(msg, e);
        }
        try {
          Thread.sleep(BIND_RETRY_DELAY_MILLIS);
        }
        catch (InterruptedException interrupted) {
          // Keep trying the remaining ports, but preserve the interrupt
          // status for the caller.
          Thread.currentThread().interrupt();
        }
      }
      catch (IOException | RuntimeException e) {
        closeQuietly(serverSocket);
        throw e;
      }
    }
  }

  /**
   * Closes a server socket, ignoring any failure. Used only for sockets that
   * were never successfully bound.
   *
   * @param socket the socket to close
   */
  private static void closeQuietly(ServerSocket socket) {
    try {
      socket.close();
    }
    catch (IOException ignored) {
      // Nothing useful can be done: the socket is being discarded anyway.
    }
  }

  /**
   * Disconnects the SSH client, logging (instead of propagating) any failure.
   */
  private void disconnectQuietly() {
    try {
      sshTunnelClient.disconnect();
    }
    catch (IOException e) {
      log.debug("Error while disconnecting SSH tunnel client", e);
    }
  }

  /**
   * Gets the tunnel's local hostname.
   *
   * @return the local hostname
   */
  public String getLocalhost() {
    return LOCALHOST;
  }

  /**
   * Gets the tunnel's local port.
   *
   * @return the local port actually bound by the tunnel
   */
  public int getLocalPort() {
    return localPort;
  }

  /**
   * Checks the tunnel's connection.
   *
   * @return true if the tunnel is connected and false otherwise.
   */
  public boolean isConnected() {
    return sshTunnelClient.isConnected();
  }

  /**
   * Closes the tunnel's local server socket. The call is synchronous; a
   * failure to close is logged and not propagated.
   */
  @Override
  public void close() {
    try {
      tunnelSocket.close();
    }
    catch (IOException e) {
      log.warn("Error while closing tunnel socket", e);
    }
  }
}
