blob: 3e8808cc3bf3408fea6e3cff0b61c906ab7f27ab [file] [log] [blame]
// Copyright 2010-2015, Google Inc.
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "base/win_api_test_helper.h"
#include <Windows.h>
#include <winnt.h>
#include <type_traits>
#include <map>
#include <memory>
#include <vector>
#include "base/util.h"
namespace mozc {
namespace {
using std::unique_ptr;
// File-local alias for the function pointer type exposed by WinAPITestHelper.
using FunctionPointer = WinAPITestHelper::FunctionPointer;

// Never called. It exists only so that the size of a WINAPI function pointer
// can be compared with FunctionPointer at compile time.
void WINAPI dummy_func() {}

static_assert(sizeof(FunctionPointer) == sizeof(&dummy_func),
              "check function pointer size failed.");
// Overlay for one import address table (IAT) slot.
// ImageThunkDataIterator::Get() reinterprets a pointer to a slot's
// |u1.Function| field as a Thunk*. Its layout is therefore expected to match
// that field exactly, so do not add members.
struct Thunk {
FunctionPointer proc;
};
class ThunkRewriter {
public:
ThunkRewriter(const Thunk *thunk, FunctionPointer proc)
: thunk_(thunk),
proc_(proc) {}
bool Rewrite() const {
// Note: There is a race condition between the first VirtualProtect and
// second VirtualProtect.
auto *writable_thunk = const_cast<Thunk *>(thunk_);
DWORD original_protect = 0;
auto result = ::VirtualProtect(
writable_thunk, sizeof(*writable_thunk), PAGE_READWRITE,
&original_protect);
if (result == 0) {
const auto error = ::GetLastError();
LOG(FATAL) << "VirtualProtect failed. error = " << error;
return false;
}
// Here we have write access to the |writable_thunk|.
writable_thunk->proc = proc_;
DWORD dummy = 0;
result = ::VirtualProtect(
writable_thunk, sizeof(*writable_thunk), original_protect, &dummy);
if (result == 0) {
const auto error = ::GetLastError();
LOG(FATAL) << "VirtualProtect failed. error = " << error;
return false;
}
return true;
}
private:
// Represents the memory address of API thunk.
const Thunk *thunk_;
// Represents the true address of API implementation.
FunctionPointer proc_;
};
class HookTargetInfo {
public:
explicit HookTargetInfo(
const vector<WinAPITestHelper::HookRequest> &requests) {
for (size_t i = 0; i < requests.size(); ++i) {
const auto &request = requests[i];
HMODULE module_handle = nullptr;
const auto result = ::GetModuleHandleExA(
GET_MODULE_HANDLE_EX_FLAG_PIN, request.module_name.c_str(),
&module_handle);
if (result == 0) {
const auto error = ::GetLastError();
LOG(FATAL) << "GetModuleHandleExA failed. error = " << error;
continue;
}
const FunctionPointer original_proc_address =
::GetProcAddress(module_handle, request.proc_name.c_str());
if (original_proc_address == nullptr) {
LOG(FATAL) << "GetProcAddress returned nullptr.";
continue;
}
string module_name = request.module_name;
Util::LowerString(&module_name);
info_[module_name][original_proc_address] = request.new_proc_address;
}
}
const bool IsTargetModule(const string &module_name) const {
string lower_module_name(module_name);
Util::LowerString(&lower_module_name);
return info_.find(lower_module_name) != info_.end();
}
const FunctionPointer GetNewProc(
const string &module_name,
FunctionPointer original_proc) const {
string lower_module_name(module_name);
Util::LowerString(&lower_module_name);
const auto module_iterator = info_.find(lower_module_name);
if (module_iterator == info_.end()) {
return nullptr;
}
const auto &proc_map = module_iterator->second;
const auto proc_iterator = proc_map.find(original_proc);
if (proc_iterator == proc_map.end()) {
return nullptr;
}
return proc_iterator->second;
}
private:
map<string, map<FunctionPointer, FunctionPointer>> info_;
};
class PortableExecutableImage {
public:
explicit PortableExecutableImage(HMODULE module_handle)
: module_handle_(module_handle),
is_invalid_image_(false) {
if (module_handle_ == nullptr) {
is_invalid_image_ = true;
return;
}
const auto *dos_header = At<IMAGE_DOS_HEADER>(0);
if (dos_header->e_magic != IMAGE_DOS_SIGNATURE) {
is_invalid_image_ = true;
return;
}
const auto *nt_header = At<IMAGE_NT_HEADERS>(dos_header->e_lfanew);
if (nt_header->Signature != IMAGE_NT_SIGNATURE) {
is_invalid_image_ = true;
return;
}
}
template <typename T>
const T *At(DWORD offset) const {
static_assert(std::is_pod<T>::value, "T should be POD.");
CHECK(!is_invalid_image_);
// TODO(yukawa): Validate if this memory range is safe to be accessed.
return reinterpret_cast<const T *>(
reinterpret_cast<const uint8 *>(module_handle_) + offset);
}
bool IsValid() const {
return !is_invalid_image_;
}
private:
HMODULE module_handle_;
bool is_invalid_image_;
};
// Walks the IMAGE_IMPORT_DESCRIPTOR array of |image|. Iteration stops at the
// first of:
// - the bound implied by the import directory's Size;
// - a descriptor whose Name is 0.
// An invalid image yields no descriptors.
class ImageImportDescriptorIterator {
 public:
  explicit ImageImportDescriptorIterator(const PortableExecutableImage &image)
      : image_(image),
        import_directory_(nullptr),
        descriptor_index_(0),
        descriptor_index_max_(0) {
    if (image_.IsValid()) {
      const auto *dos_header = image_.At<IMAGE_DOS_HEADER>(0);
      const auto *nt_header =
          image_.At<IMAGE_NT_HEADERS>(dos_header->e_lfanew);
      const auto &data_directory = nt_header->OptionalHeader.DataDirectory;
      import_directory_ = &data_directory[IMAGE_DIRECTORY_ENTRY_IMPORT];
      descriptor_index_max_ =
          import_directory_->Size / sizeof(IMAGE_IMPORT_DESCRIPTOR);
    }
  }

  // Returns the current descriptor. CHECK-fails on a terminating entry.
  const IMAGE_IMPORT_DESCRIPTOR &Get() const {
    const IMAGE_IMPORT_DESCRIPTOR &descriptor = *GetInternal();
    CHECK_NE(0, descriptor.Name);
    return descriptor;
  }

  void Next() { ++descriptor_index_; }

  // Short-circuit order matters: GetInternal() is only safe once the image is
  // valid and the index is within bounds.
  bool Done() const {
    return !image_.IsValid() ||
           descriptor_index_ >= descriptor_index_max_ ||
           GetInternal()->Name == 0;
  }

 private:
  const IMAGE_IMPORT_DESCRIPTOR *GetInternal() const {
    CHECK_LT(descriptor_index_, descriptor_index_max_);
    const DWORD offset =
        import_directory_->VirtualAddress +
        descriptor_index_ * sizeof(IMAGE_IMPORT_DESCRIPTOR);
    return image_.At<IMAGE_IMPORT_DESCRIPTOR>(offset);
  }

  const PortableExecutableImage &image_;
  const IMAGE_DATA_DIRECTORY *import_directory_;
  size_t descriptor_index_;
  size_t descriptor_index_max_;
};
// Walks the IMAGE_THUNK_DATA array at |import_descriptor.FirstThunk| until an
// entry whose u1.Function is 0.
class ImageThunkDataIterator {
 public:
  ImageThunkDataIterator(const PortableExecutableImage &image,
                         const IMAGE_IMPORT_DESCRIPTOR &import_descriptor)
      : image_(image),
        import_descriptor_(import_descriptor),
        thunk_index_(0) {}

  // Returns the current slot reinterpreted as a Thunk so that its address
  // can be rewritten. CHECK-fails once Done() is true.
  const Thunk *Get() const {
    CHECK(!Done());
    const auto &function_field = GetInternal()->u1.Function;
    return reinterpret_cast<const Thunk *>(&function_field);
  }

  void Next() { ++thunk_index_; }

  bool Done() const { return GetInternal()->u1.Function == 0; }

 private:
  const IMAGE_THUNK_DATA *GetInternal() const {
    const DWORD offset = import_descriptor_.FirstThunk +
                         thunk_index_ * sizeof(IMAGE_THUNK_DATA);
    return image_.At<IMAGE_THUNK_DATA>(offset);
  }

  const PortableExecutableImage &image_;
  const IMAGE_IMPORT_DESCRIPTOR &import_descriptor_;
  size_t thunk_index_;
};
} // namespace
// Undo log created by DoHook() and handed out as a RestoreInfoHandle.
// RestoreHook() takes ownership of it. Each entry writes an original procedure
// address back into the IAT slot it was read from.
class WinAPITestHelper::RestoreInfo {
public:
vector<ThunkRewriter> rewrites;
};
WinAPITestHelper::HookRequest::HookRequest(
const string &src_module,
const string &src_proc_name,
FunctionPointer new_proc_addr)
: module_name(src_module),
proc_name(src_proc_name),
new_proc_address(new_proc_addr) {}
// static
// Rewrites |target_module|'s import address table. Every import named in
// |requests| then points at its replacement. Returns a handle that must be
// passed to RestoreHook() to put the original addresses back.
WinAPITestHelper::RestoreInfoHandle WinAPITestHelper::DoHook(
    HMODULE target_module,
    const vector<WinAPITestHelper::HookRequest> &requests) {
  const HookTargetInfo hook_targets(requests);

  // Following code skips some data validations as this code is only used in
  // unit tests.
  PortableExecutableImage image(target_module);
  CHECK(image.IsValid());

  unique_ptr<RestoreInfo> undo_log(new RestoreInfo());
  for (ImageImportDescriptorIterator descriptors(image); !descriptors.Done();
       descriptors.Next()) {
    const IMAGE_IMPORT_DESCRIPTOR &descriptor = descriptors.Get();
    const string imported_module(image.At<char>(descriptor.Name));
    if (hook_targets.IsTargetModule(imported_module)) {
      for (ImageThunkDataIterator thunks(image, descriptor); !thunks.Done();
           thunks.Next()) {
        const Thunk *slot = thunks.Get();
        const FunctionPointer current_proc = slot->proc;
        const FunctionPointer replacement_proc =
            hook_targets.GetNewProc(imported_module, current_proc);
        if (replacement_proc != nullptr) {
          // Install the hook first, then record how to undo it.
          CHECK(ThunkRewriter(slot, replacement_proc).Rewrite());
          undo_log->rewrites.push_back(ThunkRewriter(slot, current_proc));
        }
      }
    }
  }
  return undo_log.release();
}
// Undoes a DoHook() call by writing every saved original address back into
// its IAT slot, then frees |restore_info|. Takes ownership of |restore_info|.
// A null handle is accepted and ignored; previously it was dereferenced.
void WinAPITestHelper::RestoreHook(
    WinAPITestHelper::RestoreInfoHandle restore_info) {
  unique_ptr<RestoreInfo> info(restore_info);  // takes ownership
  if (info == nullptr) {
    return;
  }
  for (const auto &rewrite : info->rewrites) {
    CHECK(rewrite.Rewrite());
  }
}
} // namespace mozc