using System;
using System.Collections.Generic;
using System.Net;
using System.Net.Sockets;
using System.Security;
using System.Text;
using System.Threading;

namespace Apewer.Network
{

    /// <summary>TCP port proxy: accepts connections on a local endpoint and relays the bytes to a remote endpoint.</summary>
    /// <remarks>
    /// Each accepted connection is paired with an outbound connection to the endpoint returned by the
    /// resolver delegate. Two thread-pool work items copy data between the pair, one per direction.
    /// </remarks>
    public class TcpProxy
    {

        IPEndPoint _local;
        Socket _listen;
        Thread _thread;
        int _backlog;
        int _port;

        // Set by Stop() before the listening socket is closed, so the accept loop can tell
        // a deliberate shutdown apart from a transient accept failure.
        volatile bool _stopped;

        Func<IPEndPoint, IPEndPoint> _remote;

        /// <summary>Gets the <see cref="Socket.Connected"/> state of the listening socket, or false when there is no listening socket.</summary>
        /// <remarks>
        /// NOTE(review): this forwards the connection state of the listening socket itself, which describes
        /// a connection to a remote peer rather than whether the proxy is listening. Confirm this is the intended signal.
        /// </remarks>
        public bool Connected { get => _listen == null ? false : _listen.Connected; }

        /// <summary>Gets the local port the proxy is bound to.</summary>
        public int Port { get => _port; }

        /// <summary>Starts the proxy by listening on the given local endpoint.</summary>
        /// <param name="local">Local endpoint to listen on.</param>
        /// <param name="remote">Resolves the remote endpoint for an accepted client; should return null when no remote endpoint is available, in which case the client connection is closed.</param>
        /// <param name="backlog">Maximum length of the pending connection queue.</param>
        /// <exception cref="ArgumentNullException"></exception>
        /// <exception cref="ArgumentOutOfRangeException"></exception>
        /// <exception cref="SocketException"></exception>
        /// <exception cref="SecurityException"></exception>
        public TcpProxy(IPEndPoint local, Func<IPEndPoint, IPEndPoint> remote, int backlog = 10000)
        {
            Start(local, remote, backlog);
        }

        // Validates the arguments, binds the listening socket and starts the background accept thread.
        void Start(IPEndPoint local, Func<IPEndPoint, IPEndPoint> remote, int backlog = 10000)
        {
            if (local == null) throw new ArgumentNullException(nameof(local));
            if (remote == null) throw new ArgumentNullException(nameof(remote));
            if (backlog < 1) throw new ArgumentOutOfRangeException(nameof(backlog));

            _local = local;
            _remote = remote;
            _backlog = backlog;

            var server = new Socket(_local.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
            try
            {
                server.Bind(_local);
                server.Listen(_backlog);
            }
            catch
            {
                // Do not leak the socket handle when binding or listening fails.
                Close(server);
                throw;
            }
            _listen = server;

            // Read the port back from the socket so that binding to port 0 reports the assigned port.
            if (server.LocalEndPoint is IPEndPoint lep) _port = lep.Port;

            _thread = new Thread(Listen);
            _thread.IsBackground = true;
            _thread.Start();
        }

        /// <summary>Stops the proxy by closing the listening socket.</summary>
        /// <remarks>Connections that are already being relayed are not closed by this method.</remarks>
        public void Stop()
        {
            // Raise the flag first: closing the socket makes the pending Accept throw,
            // and the accept loop must then see the flag and exit.
            _stopped = true;
            Close(_listen);
        }

        // Accept loop, run on the background thread until the proxy is stopped.
        void Listen()
        {
            var listener = _listen;
            while (!_stopped && listener != null)
            {
                Socket client;
                try
                {
                    client = listener.Accept();
                }
                catch (ObjectDisposedException)
                {
                    // The listening socket was closed; nothing more can ever be accepted.
                    break;
                }
                catch
                {
                    // Any other accept failure: retry; the loop condition exits if Stop() was called.
                    continue;
                }

                IPEndPoint remote = null;
                try
                {
                    remote = _remote(client.RemoteEndPoint as IPEndPoint);
                }
                catch { } // A failing resolver is treated the same as "no remote endpoint".

                if (remote == null)
                {
                    // Without a target the client can never be served; close it instead of leaking it.
                    Close(client);
                    continue;
                }

                Socket upstream = null;
                try
                {
                    // Use the address family of the resolved endpoint so IPv6 targets work as well.
                    upstream = new Socket(remote.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
                    upstream.Connect(remote);

                    // One work item per direction: client -> remote and remote -> client.
                    ThreadPool.QueueUserWorkItem(Handler, new Socket[] { client, upstream });
                    ThreadPool.QueueUserWorkItem(Handler, new Socket[] { upstream, client });
                }
                catch
                {
                    Close(client);
                    Close(upstream);
                }
            }
        }

        // Copies bytes from the first socket of the pair to the second until either side fails
        // or the source reports end of stream, then closes both sockets.
        void Handler(object obj)
        {
            var tuple = (Socket[])obj;
            var src = tuple[0];
            var dst = tuple[1];
            var buffer = new byte[1024];
            while (true)
            {
                try
                {
                    int count = src.Receive(buffer, buffer.Length, SocketFlags.None);
                    if (count < 1) break; // Source closed the connection.
                    dst.Send(buffer, count, SocketFlags.None);
                }
                catch { break; }
            }

            // Closing both ends also makes the opposite-direction handler fail and exit.
            Close(src);
            Close(dst);
        }

        // Closes a socket, ignoring null and any error raised while closing.
        static void Close(Socket socket)
        {
            if (socket == null) return;
            try { socket.Close(); } catch { }
        }

    }

}