Created
June 30, 2018 15:52
-
-
Save ntrrgc/859476b1670b196f5e0606f092276da3 to your computer and use it in GitHub Desktop.
Proof of concept for page-based watchpoints
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#define __USE_POSIX199309 | |
#include <cassert> | |
#include <signal.h> | |
#include <mutex> | |
#include <functional> | |
#include <sys/ptrace.h> | |
#include <malloc.h> | |
#include <sys/mman.h> | |
#include <unistd.h> | |
#include <set> | |
#include <functional> | |
using namespace std; | |
// Size of a memory page in bytes, as reported by sysconf(_SC_PAGESIZE) during
// static initialization.
const size_t PAGESIZE = static_cast<size_t>(sysconf(_SC_PAGESIZE));
/// Address interval delimited by `start` and `end`, bundled with a callback.
/// Kept as a plain aggregate so it can be brace-initialized as
/// `{ start, end, callback }`.
class PointerRange {
public:
    void* start;                     ///< Lowest address of the interval.
    void* end;                       ///< Address where the interval stops; size() == end - start.
    std::function<void()> onAccess;  ///< Callback stored alongside the interval.

    /// Strict ordering by `start` only; `end` and `onAccess` do not take part.
    bool operator<(const PointerRange& other) const {
        return other.start > start;
    }

    /// Distance in bytes from `start` to `end`.
    size_t size() const {
        const char* const first = static_cast<const char*>(start);
        const char* const last = static_cast<const char*>(end);
        return last - first;
    }
};
class PointerRangeList: public std::set<PointerRange> { | |
public: | |
iterator findContainingPointer(void* pointer) { | |
for (iterator i = begin(); i != end(); ++i) { | |
const PointerRange& range = *i; | |
if (range.start <= pointer && pointer < range.end) | |
return i; | |
} | |
return end(); | |
} | |
bool contains(void* pointer) const { | |
for (const PointerRange& range : *this) { | |
if (range.start <= pointer && pointer < range.end) | |
return true; | |
} | |
return false; | |
} | |
}; | |
class MemoryUsageWatcher { | |
public: | |
static void initialize() { | |
assert(!s_instance); | |
s_instance = new MemoryUsageWatcher(); | |
} | |
static MemoryUsageWatcher& instance() { | |
assert(s_instance); | |
return *s_instance; | |
} | |
// Called only from patrol thread. | |
void watchRange(void* _start, size_t size, function<void()> onAccess) { | |
char* start = (char*) _start; | |
// TODO ensure page-aligned | |
{ | |
lock_guard<mutex> lock(mutex); | |
// No intersections: | |
assert(!m_watchedPages.contains(start)); | |
assert(!m_watchedPages.contains(start + size)); | |
PointerRange allocationRange { start, start + size, onAccess }; | |
m_watchedPages.insert(allocationRange); | |
if (0 != mprotect(start, size, PROT_NONE)) { | |
perror("watchRange: "); | |
abort(); | |
} | |
} | |
} | |
// Called only from patrol thread. | |
void removeWatch(void* start) { | |
lock_guard<mutex> lock(mutex); | |
PointerRangeList::iterator rangeIter = m_watchedPages.findContainingPointer(start); | |
if (rangeIter != m_watchedPages.end()) { | |
if (0 != mprotect(rangeIter->start, rangeIter->size(), PROT_READ | PROT_WRITE)) { | |
perror("removeWatch: "); | |
abort(); | |
} | |
m_watchedPages.erase(rangeIter); | |
} | |
} | |
private: | |
static MemoryUsageWatcher* s_instance; | |
mutex m_mutex; | |
PointerRangeList m_watchedPages; | |
struct sigaction oldSigAction; | |
MemoryUsageWatcher() { | |
struct sigaction newSigAction; | |
newSigAction.sa_sigaction = MemoryUsageWatcher::segfaultHandlerWrapper; | |
newSigAction.sa_flags = SA_SIGINFO | SA_NODEFER; | |
// If our segfault handler has a bug, we want to catch it as usual, | |
// but otherwise we want no signals to interrupt the signal handler. | |
sigfillset(&newSigAction.sa_mask); | |
sigdelset(&newSigAction.sa_mask, SIGILL); | |
sigdelset(&newSigAction.sa_mask, SIGBUS); | |
sigdelset(&newSigAction.sa_mask, SIGFPE); | |
sigdelset(&newSigAction.sa_mask, SIGSEGV); | |
sigdelset(&newSigAction.sa_mask, SIGPIPE); | |
sigdelset(&newSigAction.sa_mask, SIGSTKFLT); | |
if (0 != sigaction(SIGSEGV, &newSigAction, &oldSigAction)) { | |
perror("MemoryUsageWatcher: could not set up signal handler: "); | |
abort(); | |
} | |
} | |
public: | |
void segfaultHandler(void* accessedAddress) { | |
static thread_local bool insideSegfaultHandler = false; | |
static thread_local void* accessedAddressParentHandler = nullptr; | |
if (insideSegfaultHandler) { | |
// Segfault on segfault, this either is caused by: | |
if (accessedAddress != accessedAddressParentHandler) { | |
// a) A bug in this signal handler, who accessed an invalid pointer accidentally. | |
static const char msg[] = "MemoryUsageWatcher: Internal segmentation fault\n"; | |
write(STDERR_FILENO, msg, sizeof(msg)); | |
} else { | |
// b) A bug in the application, who accessed an invalid pointer accidentally, | |
// reaching this handler, who after ensuring that the pointer was not covered | |
// but a watched, mprotect()'ed range, decided to access it to check whether | |
// it was because the region was already unprotected by another thread that | |
// got the lock just before us or it was in fact invalid memory and turned out | |
// to be the latter. | |
} | |
// Either way, we have to abort for real. | |
sigaction(SIGSEGV, &oldSigAction, nullptr); | |
raise(SIGSEGV); | |
} | |
insideSegfaultHandler = true; | |
accessedAddressParentHandler = accessedAddress; | |
// TODO Check if we inside of real malloc in this thread. If that's the | |
// case, abort immediately before more damage is done (e.g. by the code | |
// following, which may use malloc()/free(). | |
// The watched pages table can't be read and modified at the same time. | |
// Also, if two threads access the same page simultaneously, this | |
// ensures that only one executes the `onAccess()` callback. | |
lock_guard<mutex> lock(mutex); | |
PointerRangeList::iterator rangeIter = m_watchedPages.findContainingPointer(accessedAddress); | |
if (rangeIter != m_watchedPages.end()) { | |
if (0 != mprotect(rangeIter->start, rangeIter->size(), PROT_READ | PROT_WRITE)) { | |
perror("segfaultHandler: "); | |
abort(); | |
} | |
rangeIter->onAccess(); | |
m_watchedPages.erase(rangeIter); | |
} else { | |
// The address is not in the table of watched ranges. It may have | |
// been removed by another thread that acquired the lock before us, | |
// or maybe it's just a buggy pointer from the application. | |
// How can we know? Just access the pointer. In the former case, it | |
// will do nothing, in the latter, it will segfault again. | |
volatile char *pointer = (char*) accessedAddress; | |
*pointer; | |
} | |
insideSegfaultHandler = false; | |
} | |
static void segfaultHandlerWrapper(int signum, siginfo_t* siginfo, void*) { | |
instance().segfaultHandler(siginfo->si_addr); | |
} | |
}; | |
MemoryUsageWatcher* MemoryUsageWatcher::s_instance = nullptr; | |
/// Small sample object with three defaulted fields; main() places one on the
/// watched page and reads its members.
struct Potato {
    int x{10};
    char y{1};
    char z{5};
};
/// Demo: puts a Potato on its own page, watches that page twice and prints a
/// message from the callback on the first access after each watchRange() call.
/// Returns 1 if the page cannot be allocated, 0 otherwise.
int main(int argc, char** argv) {
    (void) argc;
    (void) argv;

    printf("Page size: %zu\n", PAGESIZE);
    MemoryUsageWatcher::initialize();

    // pvalloc() returns page-aligned memory rounded up to whole pages, so
    // protecting PAGESIZE bytes below only affects this allocation.
    void* const storage = pvalloc(sizeof(Potato));
    if (!storage) {
        // Without this check the placement-new below would write through a
        // null pointer.
        perror("pvalloc");
        return 1;
    }
    printf("Allocated %p\n", storage);
    Potato* p = new(storage) Potato;

    MemoryUsageWatcher::instance().watchRange(p, PAGESIZE, []() {
        printf("Potato access detected!\n");
    });
    // Only the first access faults; the handler unprotects the page, so the
    // second read goes through silently.
    printf("p->y = %d\n", p->y);
    printf("p->x = %d\n", p->x);

    MemoryUsageWatcher::instance().watchRange(p, PAGESIZE, []() {
        printf("Second potato access detected!\n");
    });
    printf("p->z = %d\n", p->z);
    printf("p->y = %d\n", p->y);

    // The read of p->z above already consumed the second watch, so this is a
    // no-op here; it guarantees the page is accessible before the destructor
    // and free() touch it.
    MemoryUsageWatcher::instance().removeWatch(p);
    p->~Potato();
    free(p);
    return 0;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment