
#include "stdafx.h"
#include "sspiex.h"

/* Plaintext payload that wmain seals with EncryptBuffer, sends to the server,
   and expects to get back (sealed) after the authentication handshake. */
#define TEST_MESSAGE "HEY DUDE!  WHAT IS UP!!"



/* Verbose-logging flag. It starts out TRUE and the -v switch only sets it to
   TRUE again, so -v currently has no effect. Nothing in this file reads it.
   NOTE(review): not static - presumably consumed by helpers in another
   translation unit (e.g. DumpSecBuffer); confirm before changing. */
BOOL verbose = TRUE;

/*
 * ConnectToServer
 *   Resolve `server` (dotted IPv4 literal or host name) and open a TCP
 *   connection to it on `port`.
 *
 *   server - wide-character host name or IPv4 address; must be shorter than
 *            MAX_SERVERNAME characters.
 *   port   - TCP port in host byte order.
 *   s      - receives the connected socket on success. It is set to
 *            INVALID_SOCKET on every failure path, so callers never hold a
 *            closed or half-initialised handle.
 *
 * Returns ERROR_SUCCESS on success. On failure it returns one of:
 *   - ERROR_INVALID_PARAMETER: missing, too long or unconvertible name.
 *   - ERROR_SERVICE_NOT_FOUND: the name did not resolve to an IPv4 address.
 *   - the WSAGetLastError() code from socket() or connect().
 */
DWORD
ConnectToServer(
    LPWSTR server, 
    USHORT port,
    SOCKET *s
    )
{
    DWORD           dwError;
    size_t          cch = 0;
    char            mbserver[MAX_SERVERNAME];
    struct hostent *pHost;
    SOCKADDR_IN     sin = {0};      /* also clears sin_zero */

    if (NULL == server || NULL == s)
        return ERROR_INVALID_PARAMETER;

    *s = INVALID_SOCKET;

    /*
     * Most XP-era Winsock functions take char*, so narrow the name first.
     * Over-long names are rejected up front: wcstombs_s calls the CRT
     * invalid-parameter handler (which terminates by default) if the result
     * does not fit. The conversion result is checked because mbserver is
     * not a usable string when it fails.
     */
    if (wcslen(server) >= MAX_SERVERNAME)
        return ERROR_INVALID_PARAMETER;

    if (0 != wcstombs_s(
                &cch, 
                mbserver, 
                sizeof(mbserver), 
                server, 
                sizeof(mbserver) - 1))
    {
        return ERROR_INVALID_PARAMETER;
    }

    sin.sin_family = AF_INET;
    sin.sin_port = htons(port);
    sin.sin_addr.s_addr = inet_addr(mbserver);

    if (INADDR_NONE == sin.sin_addr.s_addr) 
    {
        /* Not a dotted IPv4 literal - resolve it as a host name. */
        pHost = gethostbyname(mbserver);

        /* Copy only an IPv4 address of exactly the expected size. h_length
           comes from the resolver and must never overrun sin_addr. */
        if (NULL == pHost ||
            AF_INET != pHost->h_addrtype ||
            (int)sizeof(sin.sin_addr) != pHost->h_length) 
        {
            return ERROR_SERVICE_NOT_FOUND;
        }
        memcpy(&sin.sin_addr, pHost->h_addr, sizeof(sin.sin_addr));
    }

    //--------------------------------------------------------------------
    //  Create the socket.

    *s = socket(PF_INET, SOCK_STREAM, 0);
    if (INVALID_SOCKET == *s) 
        return WSAGetLastError();

    //--------------------------------------------------------------------
    //  Connect to the server.

    if (0 == connect(*s, (LPSOCKADDR)&sin, sizeof(sin)))
    {
        printf("connected - %s\n", mbserver);
        return ERROR_SUCCESS;
    }

    dwError = WSAGetLastError();
    printf("connect failed - 0x%x\n", dwError);
    closesocket(*s);
    *s = INVALID_SOCKET;    /* don't hand a closed handle back to the caller */
    return dwError;
}

/*
 * ClientAuthenticate
 *   Run the client side of an SSPI handshake over the connected socket `s`.
 *   Each token produced by InitializeSecurityContext is sent with SendMsg.
 *   The server's reply, read with ReceiveMsg, is the input for the next leg.
 *
 *   s            - connected socket to the server.
 *   target       - target (SPN) name passed to InitializeSecurityContext.
 *   package      - security package name for AcquireCredentialsHandle.
 *   id           - optional explicit credentials (NULL = current logon).
 *   capabilities - ISC_REQ_* context requirements requested from the package.
 *   pFinalCtxt   - receives the established context on success. The caller
 *                  then owns it and must DeleteSecurityContext it.
 *
 * Returns SEC_E_OK on success. Every failure path returns a failure code
 * (one for which SEC_SUCCESS is false):
 *   - SEC_E_INSUFFICIENT_MEMORY when a token buffer cannot be allocated;
 *   - SEC_E_INTERNAL_ERROR on a socket error;
 *   - the package's error code;
 *   - the status the server reported.
 */
SECURITY_STATUS 
ClientAuthenticate(
        SOCKET s,
        LPWSTR target,
        LPWSTR package,
        SEC_WINNT_AUTH_IDENTITY_EX *id,
        ULONG capabilities,
        PCtxtHandle pFinalCtxt
        )
{
    SECURITY_STATUS     status, tmp;
    DWORD               sock_status = 0;
    CredHandle          hCred;
    SecBuffer           input = {0};
    SecBuffer           output = {0};
    SecBufferDesc       sbdIn;
    SecBufferDesc       sbdOut;
    TimeStamp           expiry;
    CtxtHandle          hCtxt;
    BOOL                firstPass = TRUE;
    BOOL                allDone = FALSE;
    ULONG               attributes = 0;

    SecInvalidateHandle(&hCtxt);
    SecInvalidateHandle(&hCred);

    input.cbBuffer = MAX_TOKEN;
    input.BufferType = SECBUFFER_TOKEN;
    input.pvBuffer = LocalAlloc(LPTR, MAX_TOKEN);
    if (!input.pvBuffer)
    {
        /* Must be a SEC_E_* code: ERROR_NOT_ENOUGH_MEMORY is positive and
           would pass the caller's SEC_SUCCESS() check. */
        status = SEC_E_INSUFFICIENT_MEMORY;
        goto out;
    }

    sbdIn.ulVersion = 0;
    sbdIn.cBuffers = 1;
    sbdIn.pBuffers = &input;

    output.cbBuffer = MAX_TOKEN;
    output.BufferType = SECBUFFER_TOKEN;
    output.pvBuffer = LocalAlloc(LPTR, MAX_TOKEN);
    if (!output.pvBuffer)
    {
        status = SEC_E_INSUFFICIENT_MEMORY;
        goto out;
    }

    sbdOut.ulVersion = 0;
    sbdOut.cBuffers = 1;
    sbdOut.pBuffers = &output;

    status = AcquireCredentialsHandle(
                    NULL,
                    package,
                    SECPKG_CRED_OUTBOUND,
                    NULL,
                    id,
                    NULL,
                    NULL,
                    &hCred,
                    &expiry
                    );

    if (!SEC_SUCCESS(status))
    {
        printf("ACH failed - 0x%x\n", status);
        goto out;
    }
    
    do {

        output.cbBuffer = MAX_TOKEN;
        if (!firstPass)
            DumpSecBuffer("ISC input: ", &input, TRUE);

        status = InitializeSecurityContext(
                    &hCred,
                    (firstPass ? NULL : &hCtxt),
                    target,
                    capabilities,   /* honour the caller's ISC_REQ_* flags */
                    0,
                    SECURITY_NATIVE_DREP,
                    (firstPass ? NULL : &sbdIn),
                    0,
                    &hCtxt,
                    &sbdOut,
                    &attributes,
                    &expiry
                    );

        firstPass = FALSE;

        /* Classify the status BEFORE looking at the output size. A failure
           usually produces no token, and must not be mistaken for
           "finished". */
        if ((status == SEC_I_COMPLETE_NEEDED) ||
            (status == SEC_I_COMPLETE_AND_CONTINUE))  
        {
            tmp = CompleteAuthToken(&hCtxt, &sbdOut);
            if (!SEC_SUCCESS(tmp))  
            {
                printf("complete failed: 0x%08x\n", tmp);
                status = tmp;
                goto out;
            }

            /* COMPLETE_NEEDED is the final leg; COMPLETE_AND_CONTINUE still
               expects a reply from the server. */
            status = (status == SEC_I_COMPLETE_NEEDED) ?
                        SEC_E_OK : SEC_I_CONTINUE_NEEDED;
        } 
        else if (status != SEC_E_OK &&
                 status != SEC_I_CONTINUE_NEEDED)
        {
            /* failure - tell the server, then bail out */
            printf("ISC failed - 0x%x\n", status);
            if (SendError(s, status))                  
                printf("send error failure\n");

            /* An unexpected informational code is still a failure here. */
            if (SEC_SUCCESS(status))
                status = SEC_E_INTERNAL_ERROR;
            goto out;
        }

        /* all done - no more to send */
        if (output.cbBuffer == 0)
        {
            if (status == SEC_I_CONTINUE_NEEDED)
            {
                /* The package wants another leg but gave us nothing to send. */
                printf("ISC requested another leg but produced no token\n");
                status = SEC_E_INTERNAL_ERROR;
                goto out;
            }
            break;
        }

        printf("ISC: 0x%p:0x%p\n", (void *)hCtxt.dwLower, (void *)hCtxt.dwUpper);
        DumpSecBuffer("ISC output: ", &output, TRUE);

        /* send it */
        sock_status = SendMsg(s, (PBYTE)output.pvBuffer, output.cbBuffer);
        if (sock_status) 
        {
            printf("send failure - 0x%x\n", sock_status);
            goto out;
        }

        /* final blob sent */
        allDone = (status == SEC_E_OK);  

        /* get response from server - this is input for next pass */
        input.cbBuffer = MAX_TOKEN;
        sock_status = ReceiveMsg(
                    s, 
                    (PBYTE)input.pvBuffer, 
                    input.cbBuffer, 
                    &input.cbBuffer
                    );

        if (sock_status)
        {
            printf("recv failure - 0x%x\n", sock_status);
            goto out;
        }

        /* A reply of exactly sizeof(DWORD) bytes is treated as a status word
           from the server. NOTE(review): a genuine 4-byte token would be
           misread here; this depends on the server's framing. */
        if (input.cbBuffer == sizeof(DWORD))
        {
            status = (*((DWORD*)input.pvBuffer));
            if (status)
            {
                printf("server sent error - %x\n", status);
                goto out;
            }
            else
                printf("server authenticated client\n");
        }

    } while (!allDone);

    /* Hand ownership of the context to the caller. */
    *pFinalCtxt = hCtxt;
    SecInvalidateHandle(&hCtxt);

out:
    
    SAFE_FREE(input.pvBuffer);
    SAFE_FREE(output.pvBuffer);
    if (SecIsValidHandle(&hCred)) 
        FreeCredentialsHandle(&hCred);
    if (SecIsValidHandle(&hCtxt))
        DeleteSecurityContext(&hCtxt);

    if (sock_status) 
        status = SEC_E_INTERNAL_ERROR;

    return(status);
}


/* Command-line switches recognised by wmain (compared case-insensitively). */
#define SPN_ARG             L"-spn"
#define SERVER_ARG          L"-server"
#define PORT_ARG            L"-port"
#define PACKAGE_ARG         L"-pkg"
#define USER_ARG            L"-u"
#define DOMAIN_ARG          L"-d"
#define PASSWORD_ARG        L"-p"
#define NOSIGN_ARG          L"-no_sign"
#define NOSEAL_ARG          L"-no_seal"
#define GSS_ARG             L"-gss"
#define REQ_ARG             L"-req"
#define VERBOSE_ARG         L"-v"
#define HELP_ARG            L"-?"

/*
 * Advance the argument index `idx` to the value that follows a switch.
 * If the value is missing (the switch was the last argument), print usage
 * and return ERROR_INVALID_PARAMETER from the enclosing function.
 * It must be expanded inside a function that has `argc` in scope (wmain).
 * The do { } while (0) wrapper makes it a single statement, so it is safe
 * inside an unbraced if/else.
 */
#define GET_NEXT_PARAM(idx)                     \
    do {                                        \
        if (++(idx) >= argc)                    \
        {                                       \
            usage();                            \
            return ERROR_INVALID_PARAMETER;     \
        }                                       \
    } while (0)

/*
 * Print command-line help to stdout.
 * The switch names are wide strings, so they are formatted with %ls. %ls
 * means wchar_t* in both ISO C wprintf and the Microsoft CRT. In ISO C,
 * %s in wprintf expects a char*.
 */
void
usage(void)
{
    wprintf(L"Usage: clissp\n");
    wprintf(L"\t%ls: <spn> (default HOST/<server>)\n", SPN_ARG);
    wprintf(L"\t%ls: <server> (default local machinename)\n", SERVER_ARG);
    wprintf(L"\t%ls: <port> (default 2000)\n", PORT_ARG);
    wprintf(L"\t%ls: <package> (default negotiate)\n", PACKAGE_ARG);
    wprintf(L"\t%ls: <user> (default logged on user)\n", USER_ARG);
    wprintf(L"\t%ls: <domain> (default logged on domain)\n", DOMAIN_ARG);
    wprintf(L"\t%ls: <password> (default logged on pwd)\n", PASSWORD_ARG);
    wprintf(L"\t%ls: <0xcontext attributes> (default sign/seal)\n", REQ_ARG);
    wprintf(L"\t%ls (use GSSAPI style sign / seal token)\n", GSS_ARG);
    wprintf(L"\t%ls (turn off signing)\n", NOSIGN_ARG);
    wprintf(L"\t%ls (turn off sealing)\n", NOSEAL_ARG);
    wprintf(L"\t%ls (verbose logging)\n", VERBOSE_ARG);
    wprintf(L"\t%ls (this message)\n", HELP_ARG);
}

int 
wmain(int argc, wchar_t *argv[])
{

    DWORD               dwError;
    DWORD               len;
    SECURITY_STATUS     status;
    SOCKET              s;
    LPWSTR              targetName = NULL;
    WCHAR               spnbuf[MAX_SERVERNAME];
    LPWSTR              server = NULL;
    WCHAR               srvbuf[MAX_SERVERNAME];
    LPWSTR              user = NULL;
    LPWSTR              domain = NULL;
    WCHAR               namebuf[MAX_SERVERNAME];
    LPWSTR              password = NULL;
    LPWSTR              package = L"negotiate";
    int                 port = 2000;
    BOOL                sign = TRUE;
    BOOL                seal = TRUE;
    WSADATA             wsaData;
    CtxtHandle          hCtxt;
    SecBuffer           msg;
    SecBuffer           msgOut;
    ULONG               qop;
    BOOL                gss = FALSE;
    
    /*ULONG               contextReqs = ISC_REQ_MUTUAL_AUTH | 
                            ISC_REQ_CONFIDENTIALITY |
                            ISC_REQ_REPLAY_DETECT |
                            ISC_REQ_SEQUENCE_DETECT |
                            ISC_REQ_CONNECTION |
                            ISC_REQ_INTEGRITY |
                            ISC_REQ_EXTENDED_ERROR;*/
    ULONG               contextReqs =  ISC_REQ_CONFIDENTIALITY;
    ULONG               pkgcount;
    PSecPkgInfo         pkginfo;

   PSEC_WINNT_AUTH_IDENTITY_EXW     id = NULL;
   SecPkgContext_NegotiationInfo    SecPkgNegInfo;

   SecInvalidateHandle(&hCtxt);

    /* parse arguments */
    for (int i = 1; i < argc; i++)  
    {
        if (!_wcsicmp(argv[i], USER_ARG)) 
        {
            GET_NEXT_PARAM(i);
            user = argv[i];
        } 
        else if (!_wcsicmp(argv[i], DOMAIN_ARG)) 
        {
            GET_NEXT_PARAM(i);
            domain = argv[i];
        } 
        else if (!_wcsicmp(argv[i], PASSWORD_ARG)) 
        {
            GET_NEXT_PARAM(i);
            password = argv[i];
        } 
        else if (!_wcsicmp(argv[i], PACKAGE_ARG)) 
        {
            GET_NEXT_PARAM(i);
            package = argv[i];
        }
        else if (!_wcsicmp(argv[i], SERVER_ARG)) 
        {
            GET_NEXT_PARAM(i);
            server = argv[i];
        }
        else if (!_wcsicmp(argv[i], SPN_ARG)) 
        {
            GET_NEXT_PARAM(i);
            targetName = argv[i];
        }
        else if (!_wcsicmp(argv[i], PORT_ARG)) 
        {
            GET_NEXT_PARAM(i);
            port = wcstol(argv[i],NULL,10);
        }
        else if (!_wcsicmp(argv[i], REQ_ARG)) 
        {
            GET_NEXT_PARAM(i);
            contextReqs = wcstol(argv[i],NULL,10);
        }    
        else if (!_wcsicmp(argv[i], NOSIGN_ARG)) 
            sign = FALSE;
        else if (!_wcsicmp(argv[i], NOSEAL_ARG)) 
            seal = FALSE;
        else if (!_wcsicmp(argv[i], VERBOSE_ARG)) 
            verbose = TRUE;
        else if (!_wcsicmp(argv[i], GSS_ARG)) 
            gss = TRUE;
        else {
            usage();
            return ERROR_INVALID_PARAMETER;
        }
    }

    if (!seal)
        contextReqs &= ~( ISC_REQ_INTEGRITY | ISC_REQ_CONFIDENTIALITY);
    else if (!sign)
        contextReqs &= ~( ISC_REQ_INTEGRITY | ISC_REQ_REPLAY_DETECT);

    /* supplied credentials - prevent loopback by providing these */
    if (!user || !domain)
    {
        len = MAX_SERVERNAME;
        if(!GetUserNameEx(NameSamCompatible, namebuf, &len))
        {
            dwError = GetLastError();
            return dwError;
        }

        user = wcsrchr(namebuf, L'\\');
        if (!user)
            return ERROR_INVALID_PARAMETER;

        *user = L'\0';
        user++;
        domain = namebuf;
    }
 
    dwError = BuildAuthIdentity(
                user,
                password,
                domain,
                &id
                );

    if (dwError)
            return dwError;
    
    dwError = WSAStartup(0x0101, &wsaData);
    if (dwError)
    {
        printf("WSAStartup failed - 0x%x\n", dwError);
        return dwError;
    }

    if (!server)
    {
        len = MAX_SERVERNAME;
        /* for domain joined - use ex @todo */
        if (!GetComputerName(srvbuf, &len))
            return GetLastError();
        
        server = srvbuf;
    }

    if (!targetName)
    {
        len = MAX_SERVERNAME;
        swprintf(spnbuf,sizeof(spnbuf) , L"HOST/%s",server);
        targetName = spnbuf;
    }

    dwError = ConnectToServer(server, port, &s);
    if (dwError)
    {
        printf("ConnectToServer failed - 0x%x\n", dwError);
        return dwError;
    }

    dwError = EnumerateSecurityPackages(
                &pkgcount,
                &pkginfo
                );
    if (dwError)
    {
        printf("EnumerateSecurityPackages failed - 0x%x\n", dwError);
        return dwError;
    }

     /* authenticate to server */
    status = ClientAuthenticate(
                s,
                targetName,
                package,
                id,
                contextReqs,
                &hCtxt
                );
   
    if (!SEC_SUCCESS(status))
        return status;

    /* check package which negotiate chose */
    if (!_wcsicmp(package, L"negotiate"))
    {
        status = QueryContextAttributes(
                    &hCtxt,
                    SECPKG_ATTR_NEGOTIATION_INFO,
                    &SecPkgNegInfo );

        if (!SEC_SUCCESS(status))  
            printf("QueryContextAttributes failed - 0x%x\n", status);
        else
            wprintf(L"Package Name: %s\n", SecPkgNegInfo.PackageInfo->Name);
    }
    
    msg.pvBuffer = LocalAlloc(LPTR, MAX_TOKEN);
    msg.cbBuffer = MAX_TOKEN;
    if (!msg.pvBuffer)
        return ERROR_NOT_ENOUGH_MEMORY;

    msgOut.pvBuffer = LocalAlloc(LPTR, MAX_TOKEN);
    msgOut.cbBuffer = MAX_TOKEN;
     if (!msgOut.pvBuffer)
        return ERROR_NOT_ENOUGH_MEMORY;

    /* let's build a message, and send it */
    if (seal) 
    {
        msg.cbBuffer = (ULONG) strlen(TEST_MESSAGE) + 1;
        strcpy_s((char*)msg.pvBuffer, (MAX_TOKEN/sizeof(char)), TEST_MESSAGE);
        msg.BufferType = SECBUFFER_DATA;

        status = EncryptBuffer(
                    &hCtxt,
                    &msg,
                    &msgOut,
                    0,
                    0,
                    gss
                    );
        
        if (!SEC_SUCCESS(status))
            return status;

        dwError = SendMsg(s, (PBYTE)msgOut.pvBuffer, msgOut.cbBuffer);
        if (dwError) 
        {
            printf("send failure - 0x%x\n", dwError);
            return dwError;
        }

        /* get message back from server */
        msg.cbBuffer = MAX_TOKEN;
        dwError = ReceiveMsg(
                    s, 
                    (PBYTE)msg.pvBuffer, 
                    msg.cbBuffer, 
                    &msg.cbBuffer
                    );

        if (dwError)
        {
            printf("recv failure - 0x%x\n", dwError);
            return dwError;
        }

        msgOut.cbBuffer = MAX_TOKEN;
        status = DecryptBuffer(
                    &hCtxt,
                    &msg,
                    &msgOut,
                    &qop,
                    0,
                    gss
                    );

        if (!SEC_SUCCESS(status))
        {
            printf("DecryptBuffer failure - 0x%x\n", status);
            return status;
        }

        printf("msg from server - %s\n", (char*)msgOut.pvBuffer);

    }

    /* reset */
    msg.cbBuffer = MAX_TOKEN;
    msgOut.cbBuffer = MAX_TOKEN;

#if 0
    if (sign)
    {
        msg.cbBuffer = strlen(TEST_MESSAGE) + 1;
        strcpy((char*)msg.pvBuffer, TEST_MESSAGE);
        msg.BufferType = SECBUFFER_DATA;

        status = EncryptBuffer(
                    pCtxt,
                    &msg,
                    &msgOut,
                    0,
                    0
                    );
        
        if (!SEC_SUCCESS(status))
            return status;

        status = SendMsg(s, msgOut.pvBuffer, msgOut.cbBuffer);
        if (!SEC_SUCCESS(status)) 
        {
            printf("send failure - 0x%x\n", status);
            return status;
        }

        /* get message back from server */
        msg.cbBuffer = MAX_TOKEN;
        status = ReceiveMsg(s, msg.pvBuffer, msg.cbBuffer, &read);
        if (!SEC_SUCCESS(status))
        {
            printf("recv failure - 0x%x\n", status);
            return status;
        }

        msgOut.cbBuffer = MAX_TOKEN;
        status = DecryptBuffer(
                    pCtxt,
                    &msg,
                    &msgOut,
                    &qop,
                    0
                    );

        if (!SEC_SUCCESS(status))
        {
            printf("DecryptBuffer failure - 0x%x\n", status);
            return status;
        }

        printf("msg from server - %s\n", (char*)msgOut.pvBuffer);


    }
#endif

    SAFE_FREE(id);
    SAFE_FREE(msgOut.pvBuffer);
    SAFE_FREE(msg.pvBuffer);
    if (SecIsValidHandle(&hCtxt))
        DeleteSecurityContext(&hCtxt);
    shutdown (s, 2);
    closesocket (s);
    WSACleanup ();
    
    return (ERROR_SUCCESS);
}  // end main
