#if NET40 || NET461

// Copyright © 2012, 2016 Oracle and/or its affiliates. All rights reserved.
//
// MySQL Connector/NET is licensed under the terms of the GPLv2
// <http://www.gnu.org/licenses/old-licenses/gpl-2.0.html>, like most 
// MySQL Connectors. There are special exceptions to the terms and 
// conditions of the GPLv2 as it is applied to this software, see the 
// FLOSS License Exception
// <http://www.mysql.com/about/legal/licensing/foss-exception.html>.
//
// This program is free software; you can redistribute it and/or modify 
// it under the terms of the GNU General Public License as published 
// by the Free Software Foundation; version 2 of the License.
//
// This program is distributed in the hope that it will be useful, but 
// WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY 
// or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 
// for more details.
//
// You should have received a copy of the GNU General Public License along 
// with this program; if not, write to the Free Software Foundation, Inc., 
// 51 Franklin St, Fifth Floor, Boston, MA 02110-1301  USA

using System.IO;
using System;
using Externals.MySql.Data.MySqlClient.Properties;
using Externals.MySql.Data.Common;
using System.Runtime.InteropServices;
using System.Diagnostics;
using System.Security;

namespace Externals.MySql.Data.MySqlClient.Authentication
{
  /// <summary>
  /// Client-side implementation of the <c>authentication_windows_client</c> plugin.
  /// Each call to <c>MoreData</c> advances an SSPI "Negotiate" handshake by one leg
  /// and returns the token produced by <c>InitializeSecurityContext</c>.
  /// </summary>
  [SuppressUnmanagedCodeSecurityAttribute()]
  internal class MySqlWindowsAuthenticationPlugin : MySqlAuthenticationPlugin
  {
    // Credentials handle filled in by AcquireCredentialsHandle.
    SECURITY_HANDLE outboundCredentials = new SECURITY_HANDLE(0);
    // Security context handle filled in by InitializeSecurityContext.
    SECURITY_HANDLE clientContext = new SECURITY_HANDLE(0);
    // Expiry written by AcquireCredentialsHandle; never read back in this class.
    SECURITY_INTEGER lifetime = new SECURITY_INTEGER(0);
    // True while the handshake still expects another leg.
    bool continueProcessing;
    // Target name handed to InitializeSecurityContext; never assigned in this class.
    string targetName = null;

    /// <summary>
    /// Rejects platforms on which this plugin cannot be used, then runs the base checks.
    /// </summary>
    /// <exception cref="MySqlException">
    /// The current platform is detected as Unix or Mac OS X.
    /// </exception>
    protected override void CheckConstraints()
    {
      string platform = String.Empty;
      int p = (int)Environment.OSVersion.Platform;
      // 4 is PlatformID.Unix; 128 is presumably the value legacy Mono runtimes
      // reported for Unix -- TODO confirm.
      if ((p == 4) || (p == 128))
        platform = "Unix";
      else if (Environment.OSVersion.Platform == PlatformID.MacOSX)
        platform = "Mac OS/X";

      if (!String.IsNullOrEmpty(platform))
        throw new MySqlException(String.Format(Resources.WinAuthNotSupportOnPlatform, platform));

      base.CheckConstraints();
    }

    /// <summary>
    /// Returns the user name supplied by the base class, or <c>"auth_windows"</c>
    /// when the base class reports none.
    /// </summary>
    public override string GetUsername()
    {
      string username = base.GetUsername();
      if (String.IsNullOrEmpty(username))
        return "auth_windows";
      return username;
    }

    /// <summary>
    /// Name of this plugin.
    /// </summary>
    public override string PluginName
    {
      get { return "authentication_windows_client"; }
    }

    /// <summary>
    /// Runs one leg of the handshake.
    /// </summary>
    /// <param name="moreData">
    /// Blob to feed into the security context, or <c>null</c> to acquire credentials
    /// and start a new handshake.
    /// </param>
    /// <returns>
    /// The client token for this leg, or <c>null</c> when the handshake is finished
    /// or no token was produced (the SSPI handles are released in that case).
    /// </returns>
    /// <exception cref="MySqlException">An SSPI call reported a failure status.</exception>
    protected override byte[] MoreData(byte[] moreData)
    {
      byte[] clientBlob = null;

      try
      {
        if (moreData == null)
          AcquireCredentials();

        if (continueProcessing)
          InitializeClient(out clientBlob, moreData, out continueProcessing);
      }
      catch
      {
        // A failed leg ends the handshake; do not leave SSPI handles behind.
        ReleaseSecurityHandles();
        throw;
      }

      if (!continueProcessing || clientBlob == null || clientBlob.Length == 0)
      {
        ReleaseSecurityHandles();
        return null;
      }
      return clientBlob;
    }

    /// <summary>
    /// Frees the credentials handle and the security context, then zeroes both fields
    /// so a later call cannot hand stale handle values back to SSPI.
    /// </summary>
    private void ReleaseSecurityHandles()
    {
      FreeCredentialsHandle(ref outboundCredentials);
      DeleteSecurityContext(ref clientContext);
      outboundCredentials = new SECURITY_HANDLE(0);
      clientContext = new SECURITY_HANDLE(0);
    }

    /// <summary>
    /// Calls <c>InitializeSecurityContext</c> once and copies out the resulting token.
    /// </summary>
    /// <param name="clientBlob">Receives the token bytes, or <c>null</c> if the token is empty.</param>
    /// <param name="serverBlob">
    /// Input token for this leg; <c>null</c> selects the first-leg call that has no
    /// existing context and no input buffer.
    /// </param>
    /// <param name="continueProcessing">
    /// Receives <c>false</c> when the status is <c>SEC_E_OK</c> or
    /// <c>SEC_I_COMPLETE_NEEDED</c>, <c>true</c> otherwise.
    /// </param>
    /// <exception cref="MySqlException">
    /// The status is neither success nor one of the continue/complete codes.
    /// </exception>
    void InitializeClient(out byte[] clientBlob, byte[] serverBlob, out bool continueProcessing)
    {
      clientBlob = null;
      continueProcessing = true;
      SecBufferDesc clientBufferDesc = new SecBufferDesc(MAX_TOKEN_SIZE);
      SECURITY_INTEGER initLifetime = new SECURITY_INTEGER(0);
      int ss = -1;
      try
      {
        uint contextAttributes = 0;

        if (serverBlob == null)
        {
          ss = InitializeSecurityContext(
              ref outboundCredentials,
              IntPtr.Zero,
              targetName,
              STANDARD_CONTEXT_ATTRIBUTES,
              0,
              SECURITY_NETWORK_DREP,
              IntPtr.Zero, /* always zero first time around */
              0,
              out clientContext,
              out clientBufferDesc,
              out contextAttributes,
              out initLifetime);
        }
        else
        {
          SecBufferDesc serverBufferDesc = new SecBufferDesc(serverBlob);

          try
          {
            ss = InitializeSecurityContext(ref outboundCredentials,
                ref clientContext,
                targetName,
                STANDARD_CONTEXT_ATTRIBUTES,
                0,
                SECURITY_NETWORK_DREP,
                ref serverBufferDesc,
                0,
                out clientContext,
                out clientBufferDesc,
                out contextAttributes,
                out initLifetime);
          }
          finally
          {
            // The input descriptor owns unmanaged memory; release it on every path.
            serverBufferDesc.Dispose();
          }
        }

        if ((SEC_I_COMPLETE_NEEDED == ss)
            || (SEC_I_COMPLETE_AND_CONTINUE == ss))
        {
          CompleteAuthToken(ref clientContext, ref clientBufferDesc);
        }

        if (ss != SEC_E_OK &&
            ss != SEC_I_CONTINUE_NEEDED &&
            ss != SEC_I_COMPLETE_NEEDED &&
            ss != SEC_I_COMPLETE_AND_CONTINUE)
        {
          throw new MySqlException(
              "InitializeSecurityContext() failed  with errorcode " + ss);
        }

        clientBlob = clientBufferDesc.GetSecBufferByteArray();
      }
      finally
      {
        clientBufferDesc.Dispose();
      }
      continueProcessing = (ss != SEC_E_OK && ss != SEC_I_COMPLETE_NEEDED);
    }

    /// <summary>
    /// Acquires outbound "Negotiate" credentials and marks the handshake as in progress.
    /// </summary>
    /// <exception cref="MySqlException">
    /// <c>AcquireCredentialsHandle</c> returned a status other than <c>SEC_E_OK</c>.
    /// </exception>
    private void AcquireCredentials()
    {
      continueProcessing = true;

      int ss = AcquireCredentialsHandle(null, "Negotiate", SECPKG_CRED_OUTBOUND,
              IntPtr.Zero, IntPtr.Zero, IntPtr.Zero, IntPtr.Zero, ref outboundCredentials,
              ref lifetime);
      if (ss != SEC_E_OK)
        throw new MySqlException("AcquireCredentialsHandle failed with errorcode " + ss);
    }

    #region SSPI Constants and Imports

    // SSPI status codes, with the values defined in the Windows SDK (winerror.h).
    const int SEC_E_OK = 0;
    const int SEC_I_CONTINUE_NEEDED = 0x90312;
    const int SEC_I_COMPLETE_NEEDED = 0x90313;
    const int SEC_I_COMPLETE_AND_CONTINUE = 0x90314;

    const int SECPKG_CRED_OUTBOUND = 2;
    const int SECURITY_NETWORK_DREP = 0;
    const int SECURITY_NATIVE_DREP = 0x10;
    const int SECPKG_CRED_INBOUND = 1;
    // Size in bytes of the output token buffer allocated for every leg.
    const int MAX_TOKEN_SIZE = 12288;
    const int SECPKG_ATTR_SIZES = 0;
    // No context requirement flags are requested.
    const int STANDARD_CONTEXT_ATTRIBUTES = 0;

    // pGetKeyFn is a function pointer in the native signature, so it is declared
    // pointer-sized here.
    [DllImport("secur32", CharSet = CharSet.Unicode)]
    static extern int AcquireCredentialsHandle(
        string pszPrincipal,
        string pszPackage,
        int fCredentialUse,
        IntPtr PAuthenticationID,
        IntPtr pAuthData,
        IntPtr pGetKeyFn,
        IntPtr pvGetKeyArgument,
        ref SECURITY_HANDLE phCredential,
        ref SECURITY_INTEGER ptsExpiry);

    // First-leg overload: no existing context and no input buffer.
    // NOTE(review): pOutput is declared `out` although the caller pre-allocates the
    // buffer it describes; this presumably relies on the struct being passed by
    // pointer without copying -- confirm before changing the marshalling.
    [DllImport("secur32", CharSet = CharSet.Unicode, SetLastError = true)]
    static extern int InitializeSecurityContext(
        ref SECURITY_HANDLE phCredential,
        IntPtr phContext,
        string pszTargetName,
        int fContextReq,
        int Reserved1,
        int TargetDataRep,
        IntPtr pInput,
        int Reserved2,
        out SECURITY_HANDLE phNewContext,
        out SecBufferDesc pOutput,
        out uint pfContextAttr,
        out SECURITY_INTEGER ptsExpiry);

    // Follow-up overload: existing context plus an input buffer descriptor.
    [DllImport("secur32", CharSet = CharSet.Unicode, SetLastError = true)]
    static extern int InitializeSecurityContext(
        ref SECURITY_HANDLE phCredential,
        ref SECURITY_HANDLE phContext,
        string pszTargetName,
        int fContextReq,
        int Reserved1,
        int TargetDataRep,
        ref SecBufferDesc SecBufferDesc,
        int Reserved2,
        out SECURITY_HANDLE phNewContext,
        out SecBufferDesc pOutput,
        out uint pfContextAttr,
        out SECURITY_INTEGER ptsExpiry);

    [DllImport("secur32", CharSet = CharSet.Unicode, SetLastError = true)]
    static extern int CompleteAuthToken(
        ref SECURITY_HANDLE phContext,
        ref SecBufferDesc pToken);

    [DllImport("secur32.Dll", CharSet = CharSet.Unicode, SetLastError = false)]
    public static extern int QueryContextAttributes(
        ref SECURITY_HANDLE phContext,
        uint ulAttribute,
        out SecPkgContext_Sizes pContextAttributes);

    [DllImport("secur32.Dll", CharSet = CharSet.Unicode, SetLastError = false)]
    public static extern int FreeCredentialsHandle(ref SECURITY_HANDLE pCred);

    [DllImport("secur32.Dll", CharSet = CharSet.Unicode, SetLastError = false)]
    public static extern int DeleteSecurityContext(ref SECURITY_HANDLE pCred);

    #endregion
  }

  /// <summary>
  /// Buffer descriptor handed to the SSPI calls. It always describes exactly one
  /// <c>SecBuffer</c>, which is stored in unmanaged memory owned by this struct
  /// and released by <c>Dispose</c>.
  /// </summary>
  [StructLayout(LayoutKind.Sequential)]
  struct SecBufferDesc : IDisposable
  {
    public int ulVersion;
    public int cBuffers;
    public IntPtr pBuffers; // points to the single unmanaged SecBuffer

    /// <summary>
    /// Creates a descriptor around a freshly allocated buffer of
    /// <paramref name="bufferSize"/> bytes.
    /// </summary>
    public SecBufferDesc(int bufferSize)
    {
      SecBuffer tokenBuffer = new SecBuffer(bufferSize);
      ulVersion = (int)SecBufferType.SECBUFFER_VERSION;
      cBuffers = 1;
      pBuffers = CopyToUnmanaged(tokenBuffer);
    }

    /// <summary>
    /// Creates a descriptor around an unmanaged copy of <paramref name="secBufferBytes"/>.
    /// </summary>
    public SecBufferDesc(byte[] secBufferBytes)
    {
      SecBuffer tokenBuffer = new SecBuffer(secBufferBytes);
      ulVersion = (int)SecBufferType.SECBUFFER_VERSION;
      cBuffers = 1;
      pBuffers = CopyToUnmanaged(tokenBuffer);
    }

    // Allocates unmanaged memory for one SecBuffer and marshals the given value into it.
    private static IntPtr CopyToUnmanaged(SecBuffer buffer)
    {
      IntPtr target = Marshal.AllocHGlobal(Marshal.SizeOf(buffer));
      Marshal.StructureToPtr(buffer, target, false);
      return target;
    }

    /// <summary>
    /// Releases the inner buffer and the memory holding it. Does nothing when
    /// <c>pBuffers</c> is already zero.
    /// </summary>
    public void Dispose()
    {
      if (pBuffers == IntPtr.Zero)
        return;

      Debug.Assert(cBuffers == 1);
      SecBuffer inner = (SecBuffer)Marshal.PtrToStructure(pBuffers, typeof(SecBuffer));
      inner.Dispose();
      Marshal.FreeHGlobal(pBuffers);
      pBuffers = IntPtr.Zero;
    }

    /// <summary>
    /// Copies the bytes of the inner buffer into a managed array.
    /// </summary>
    /// <returns>The buffer content, or <c>null</c> when its length is not positive.</returns>
    /// <exception cref="InvalidOperationException"><c>pBuffers</c> is zero.</exception>
    public byte[] GetSecBufferByteArray()
    {
      if (pBuffers == IntPtr.Zero)
        throw new InvalidOperationException("Object has already been disposed!!!");

      Debug.Assert(cBuffers == 1);
      SecBuffer inner = (SecBuffer)Marshal.PtrToStructure(pBuffers, typeof(SecBuffer));
      if (inner.cbBuffer <= 0)
        return null;

      byte[] bytes = new byte[inner.cbBuffer];
      Marshal.Copy(inner.pvBuffer, bytes, 0, inner.cbBuffer);
      return bytes;
    }
  }

  /// <summary>
  /// Numeric constants used when filling <c>SecBufferDesc.ulVersion</c> and
  /// <c>SecBuffer.BufferType</c>.
  /// </summary>
  internal enum SecBufferType : int
  {
    // Stored in SecBufferDesc.ulVersion; shares the value 0 with SECBUFFER_EMPTY.
    SECBUFFER_VERSION = 0,
    SECBUFFER_EMPTY = 0,
    SECBUFFER_DATA = 1,
    // Default buffer type assigned by the SecBuffer constructors.
    SECBUFFER_TOKEN = 2
  }

  /// <summary>
  /// Two pointer-sized fields laid out sequentially (the original author notes
  /// this mirrors PCtxtHandle).
  /// </summary>
  [StructLayout(LayoutKind.Sequential)]
  internal struct SecHandle
  {
    // ULONG_PTR maps to IntPtr rather than uint; this matters on 64-bit platforms.
    private IntPtr dwLower;
    private IntPtr dwUpper;
  }

  /// <summary>
  /// A single buffer (length, type, pointer) whose memory is allocated with
  /// <c>Marshal.AllocHGlobal</c> and freed by <c>Dispose</c>.
  /// </summary>
  [StructLayout(LayoutKind.Sequential)]
  internal struct SecBuffer : IDisposable
  {
    public int cbBuffer;
    public int BufferType;
    public IntPtr pvBuffer;

    /// <summary>
    /// Allocates an uninitialised token buffer of <paramref name="bufferSize"/> bytes.
    /// </summary>
    internal SecBuffer(int bufferSize)
    {
      BufferType = (int)SecBufferType.SECBUFFER_TOKEN;
      cbBuffer = bufferSize;
      pvBuffer = Marshal.AllocHGlobal(bufferSize);
    }

    /// <summary>
    /// Allocates a token buffer holding a copy of <paramref name="secBufferBytes"/>.
    /// </summary>
    internal SecBuffer(byte[] secBufferBytes)
      : this(secBufferBytes, SecBufferType.SECBUFFER_TOKEN)
    {
    }

    /// <summary>
    /// Allocates a buffer of the given type holding a copy of
    /// <paramref name="secBufferBytes"/>.
    /// </summary>
    internal SecBuffer(byte[] secBufferBytes, SecBufferType bufferType)
    {
      int length = secBufferBytes.Length;
      cbBuffer = length;
      BufferType = (int)bufferType;
      pvBuffer = Marshal.AllocHGlobal(length);
      Marshal.Copy(secBufferBytes, 0, pvBuffer, length);
    }

    /// <summary>
    /// Frees the unmanaged memory, if any, and zeroes the pointer.
    /// </summary>
    public void Dispose()
    {
      if (pvBuffer == IntPtr.Zero)
        return;

      Marshal.FreeHGlobal(pvBuffer);
      pvBuffer = IntPtr.Zero;
    }
  }

  /// <summary>
  /// Two 32-bit halves laid out sequentially; used for the expiry
  /// (<c>ptsExpiry</c>) arguments of the SSPI P/Invoke signatures.
  /// </summary>
  [StructLayout(LayoutKind.Sequential)]
  internal struct SECURITY_INTEGER
  {
    public uint LowPart;
    public int HighPart;

    /// <summary>
    /// Creates a zeroed value. <paramref name="dummy"/> is ignored.
    /// </summary>
    public SECURITY_INTEGER(int dummy)
    {
      HighPart = 0;
      LowPart = 0u;
    }
  }

  /// <summary>
  /// Two pointer-sized halves laid out sequentially; used for the credential and
  /// context handle arguments of the SSPI P/Invoke signatures.
  /// </summary>
  [StructLayout(LayoutKind.Sequential)]
  internal struct SECURITY_HANDLE
  {
    public IntPtr LowPart;
    public IntPtr HighPart;

    /// <summary>
    /// Creates a handle with both halves zero. <paramref name="dummy"/> is ignored.
    /// </summary>
    public SECURITY_HANDLE(int dummy)
    {
      HighPart = IntPtr.Zero;
      LowPart = IntPtr.Zero;
    }
  }

  /// <summary>
  /// Output structure of the <c>QueryContextAttributes</c> P/Invoke declaration.
  /// Field names follow the Win32 structure of the same name; the values are
  /// presumably sizes in bytes -- confirm against the SSPI documentation.
  /// </summary>
  [StructLayout(LayoutKind.Sequential)]
  internal struct SecPkgContext_Sizes
  {
    public uint cbMaxToken;
    public uint cbMaxSignature;
    public uint cbBlockSize;
    public uint cbSecurityTrailer;
  }

}


#endif