#include "stdafx.h"
#include "ActiveFunctions.h"
#include "Engine.h"
#include "Utils/StackTrace.h"
#include "Utils/Cache.h"

namespace storm {

	/**
	 * Collect stack traces for all UThreads on a single OS thread.
	 */
	class TraceCollector {
	public:
		// Resulting stack traces:
		vector<::StackTrace> traces;

		// Collect stack traces. If 'thisThread' is true, then we also capture this UThread.
		void capture(Bool includeThisThread) {
			vector<os::UThread> stacks = os::Thread::current().idleUThreads();

			if (includeThisThread) {
				captureCurrent();
			}

			for (size_t i = 0; i < stacks.size(); i++) {
				if (!stacks[i].detour(util::memberVoidFn(this, &TraceCollector::captureCurrent))) {
					WARNING(L"Failed to execute detour!");
				}
			}
		}

	private:
		// Capture the current UThread.
		void captureCurrent() {
			traces.push_back(stackTrace());
		}
	};


	// Pause all threads known to 'engine' (except the calling one) and collect the stack traces
	// of their UThreads into 'data'. Returns when all threads have reported back.
	PauseThreads::PauseThreads(Engine &engine) : signal(0), wait(0) {
		data = new ActiveFunctions();

		vector<os::Thread> all = engine.allThreads();
		threadCount = all.size();

		// Start 'threadMain' on every thread except the one we are running on. Afterwards,
		// 'threadCount' is the number of threads that will report back through 'signal'.
		os::Thread self = os::Thread::current();
		for (size_t at = 0; at < all.size(); at++) {
			if (all[at] == self) {
				threadCount--;
				continue;
			}

			os::UThread::spawn(util::memberVoidFn(this, &PauseThreads::threadMain), &all[at]);
		}

		// Capture the call stacks of this thread (including the current UThread) while the
		// other threads are busy with theirs.
		captureStacks(true);

		// Wait until every other thread has entered 'threadMain' and has collected its stack traces.
		// Note: This halts the user-level scheduler for this thread. Otherwise, the stacks we captured
		// above might no longer be alive when we have finished waiting. We *could* capture our stacks
		// after the other threads are done, but that makes it much harder to write robust reload tests
		// (and can easily be confusing when using the PauseThreads class, as one would not really
		// expect it to allow other UThreads to run when it pauses threads).
		for (size_t left = threadCount; left > 0; left--)
			signal.down();

		// All traces have been added: prepare the data for consumption.
		data->done();
	}

	// Release the collected data and let all paused threads continue. Does not return until
	// every paused thread has acknowledged that it woke up.
	PauseThreads::~PauseThreads() {
		// Empty the data before dropping our reference, to avoid stale references.
		data->clear();
		data->release();

		// Make sure all operations from this thread are visible to others.
		dataBarrier();

		// Wake the paused threads, one 'up' for each of them.
		for (size_t left = threadCount; left > 0; left--)
			wait.up();

		// Then wait for each of them to report back. Otherwise, we might destroy 'wait' too early.
		for (size_t left = threadCount; left > 0; left--)
			signal.down();

		// Clear our local ICache as well.
		clearLocalICache();
	}

	// Body of the UThread that the constructor spawns on each thread that is to be paused.
	// Collects the stack traces of that thread, reports back, and then blocks until the
	// destructor of PauseThreads lets it continue.
	void PauseThreads::threadMain() {
		// Capture threads, don't include this thread.
		captureStacks(false);

		// Tell the main thread that our stack traces have been collected.
		signal.up();

		// Start waiting. The destructor of PauseThreads wakes us by calling 'wait.up()'.
		wait.down();

		// Tell the main thread that we have woken up. The destructor waits for this signal so
		// that 'wait' is not destroyed too early.
		signal.up();

		// Make any updates visible to this core.
		clearLocalICache();
	}

	// Collect stack traces for the calling thread and add them to 'data'. 'includeCurrent'
	// is forwarded to TraceCollector::capture and decides whether the calling UThread itself
	// is captured too.
	void PauseThreads::captureStacks(Bool includeCurrent) {
		// Collect the traces first, before the lock is taken.
		TraceCollector gathered;
		gathered.capture(includeCurrent);

		// Additions to 'data' are made while holding 'dataLock'.
		util::Lock::L guard(dataLock);
		data->addThread(gathered.traces);
	}


	/**
	 * ActiveFunctions (the data collected by PauseThreads).
	 */

	// Output an ActiveOffset as "<offset> x<count>".
	wostream &operator <<(wostream &to, const ActiveOffset &offset) {
		to << offset.offset;
		to << L" x" << offset.count;
		return to;
	}


	/**
	 * Ordering predicate over indices into a vector of stack frames. An index is ordered by the
	 * numeric value of the pointer returned by 'fn()' for the frame it refers to. The mixed
	 * overloads allow comparing an index directly against a raw pointer, as needed when a
	 * pointer is passed as the value to std::lower_bound and std::upper_bound.
	 */
	class FrameCompare {
	public:
		// The frames that the compared indices refer to.
		const vector<StackFrame> &frames;

		FrameCompare(const vector<StackFrame> &frames) : frames(frames) {}

		// Index against index.
		bool operator() (size_t a, size_t b) const {
			return key(a) < key(b);
		}

		// Index against pointer.
		bool operator() (size_t a, const void *b) const {
			return key(a) < size_t(b);
		}

		// Pointer against index.
		bool operator() (const void *a, size_t b) const {
			return size_t(a) < key(b);
		}

	private:
		// The sort key of the frame at 'index': the pointer from 'fn()' as an integer.
		size_t key(size_t index) const {
			return size_t(frames[index].fn());
		}
	};

	// Create an empty object. The reference count starts at one, so the creator holds the
	// initial reference and gives it up by calling 'release'.
	ActiveFunctions::ActiveFunctions() : refs(1) {}

	// Add one reference to this object. The count is updated through 'atomicIncrement'.
	void ActiveFunctions::addRef() {
		atomicIncrement(refs);
	}

	// Drop one reference to this object. The object deletes itself when 'atomicDecrement'
	// reports zero, so it must not be used by the caller after this call.
	void ActiveFunctions::release() {
		if (atomicDecrement(refs) != 0)
			return;

		delete this;
	}

	// Find all captured stack frames whose 'fn()' pointer lies in the range from 'function' to
	// 'function' + codeSize(function), both ends included. Returns one ActiveOffset for each
	// distinct offset (relative to 'function'), with 'count' set to the number of frames at
	// that offset. Relies on 'sortedFrames' having been built by 'done'.
	vector<ActiveOffset> ActiveFunctions::find(const void *function) const {
		typedef vector<size_t>::const_iterator Iter;

		const byte *codeStart = (const byte *)function;
		const void *codeEnd = codeStart + runtime::codeSize(function);

		// Binary search for the range of frames inside the function.
		FrameCompare byAddress(frames);
		Iter from = std::lower_bound(sortedFrames.begin(), sortedFrames.end(), codeStart, byAddress);
		Iter to = std::upper_bound(sortedFrames.begin(), sortedFrames.end(), codeEnd, byAddress);

		vector<ActiveOffset> result;
		for (Iter at = from; at != to; ++at) {
			size_t offset = (const byte *)frames[*at].fn() - codeStart;

			// The frames are visited in address order, so equal offsets are adjacent and it
			// is enough to compare with the most recent entry.
			if (!result.empty() && result.back().offset == offset)
				result.back().count++;
			else
				result.push_back(ActiveOffset(offset, 1));
		}

		return result;
	}

	size_t ActiveFunctions::replace(const void *function, size_t offset, const void *replace, size_t rOffset) const {
		const void *exactPtr = (const byte *)function + offset;
		const void *replacePtr = (const byte *)replace + rOffset;

		FrameCompare compare(frames);
		vector<size_t>::const_iterator found =
			std::lower_bound(sortedFrames.begin(), sortedFrames.end(), exactPtr, compare);

		code::Binary *replaceBinary = code::codeBinary(replace);
		code::Arena *arena = replaceBinary->engine().arena();

		size_t replaced = 0;
		while (!compare(exactPtr, *found)) {
			const StackFrame &frame = frames[*found];
			if (frame.returnLocation) {
				frame.updateReturnLocation(replacePtr);
				// *(const void **)frame.returnLocation = replacePtr;
				replaced++;

				// Update any EH information on the stack:
				if (*found + 1 < frames.size())
					arena->updateEhInfo(replace, rOffset, frames[*found + 1].returnLocation);
			}
			++found;
		}

		return replaced;
	}

	void ActiveFunctions::addThread(const vector<::StackTrace> &src) {
		threadStart.push_back(uthreadStart.size());

		for (size_t i = 0; i < src.size(); i++) {
			uthreadStart.push_back(frames.size());

			const ::StackTrace &st = src[i];
			for (nat j = 0; j < st.count(); j++)
				frames.push_back(st[j]);
		}
	}

	// Build 'sortedFrames': the indices of all elements in 'frames', ordered according to
	// FrameCompare. To be called when all threads have been added.
	void ActiveFunctions::done() {
		// Note: We know that the function allocations in the GC are pinned at this point, so it
		// makes sense to do this here, even without location dependency objects, etc.

		const size_t total = frames.size();
		sortedFrames.reserve(total);
		for (size_t index = 0; index != total; ++index)
			sortedFrames.push_back(index);

		std::sort(sortedFrames.begin(), sortedFrames.end(), FrameCompare(frames));
	}

	void ActiveFunctions::clear() {
		frames.clear();
		uthreadStart.clear();
		threadStart.clear();
	}

}
