Created July 16, 2017 19:27
-
-
Save bartvm/4da4835ec21a12e4e8d657efeb1e1f04 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp | |
index f1e09b0e..c50f507e 100644 | |
--- a/torch/csrc/autograd/engine.cpp | |
+++ b/torch/csrc/autograd/engine.cpp | |
@@ -21,6 +21,10 @@ | |
#include <THC/THC.h> | |
#endif | |
+void tid() { | |
+ printf("%d ", (int)std::hash<std::thread::id>()(std::this_thread::get_id())); | |
+} | |
+ | |
using thpp::Tensor; | |
namespace torch { namespace autograd { | |
@@ -89,14 +93,21 @@ auto ReadyQueue::push_front(FunctionTask item) -> void { | |
{ | |
std::lock_guard<std::mutex> lock(mutex); | |
++item.base->outstanding_tasks; | |
+ tid(); | |
+ printf("Pushed task onto queue, %llu outstanding\n", item.base->outstanding_tasks.load()); | |
queue.push_front(std::move(item)); | |
} | |
not_empty.notify_one(); | |
} | |
auto ReadyQueue::pop_back() -> FunctionTask { | |
+ tid(); | |
+ printf("Getting lock\n"); | |
std::unique_lock<std::mutex> lock(mutex); | |
- not_empty.wait(lock, [this]{ return !queue.empty(); }); | |
+ printf("Waiting for a task\n"); | |
+ if (queue.empty()) { | |
+ not_empty.wait(lock, [this]{ return !queue.empty(); }); | |
+ } | |
auto task = std::move(queue.back()); queue.pop_back(); | |
return task; | |
} | |
@@ -110,20 +121,35 @@ Engine::~Engine() = default; | |
auto Engine::thread_main(std::shared_ptr<ReadyQueue> queue, int device) -> void { | |
THInferNumThreads(); | |
AutoGPU guard(device); | |
+ tid(); | |
+ printf("Starting endless loop main thread\n"); | |
while (1) { | |
+ tid(); | |
+ printf("Trying to get next task\n"); | |
FunctionTask task = queue->pop_back(); | |
+ tid(); | |
+ printf("Got a task\n"); | |
if (!task.base->has_error.load()) { | |
try { | |
+ tid(); | |
+ printf("About to evaluate function, %llu outstanding\n", task.base->outstanding_tasks.load()); | |
evaluate_function(task); | |
} catch (std::exception& e) { | |
thread_on_exception(task, e); | |
} | |
} | |
+ tid(); | |
+ printf("Evaluated function, %llu outstanding\n", task.base->outstanding_tasks.load() - 1); | |
if (--task.base->outstanding_tasks == 0) { | |
std::lock_guard<std::mutex> lock(task.base->mutex); | |
- task.base->not_done.notify_all(); | |
+ task.base->not_done.notify_one(); | |
+ tid(); | |
+ printf("Breaking free!\n"); | |
+ break; | |
} | |
} | |
+ tid(); | |
+ printf("Ending main thread\n"); | |
} | |
auto Engine::thread_on_exception(FunctionTask& task, std::exception& e) -> void { | |
@@ -299,7 +325,9 @@ auto Engine::execute(const function_list& input_roots, | |
variable_list& inputs, | |
bool keep_graph, | |
const callback_map& callbacks) -> void { | |
- std::call_once(start_threads_flag, &Engine::start_threads, this); | |
+ tid(); | |
+ printf("Engine starting threads\n"); | |
+ start_threads(); | |
// Callbacks are only valid for the duration of this run and should always be cleared | |
ClearCallbacks _cb_guard(post_callbacks, post_callbacks_lock); | |
@@ -310,6 +338,8 @@ auto Engine::execute(const function_list& input_roots, | |
function_queue roots; | |
for (auto entry : input_roots) { | |
if (entry.first->is_executable) { | |
+ tid(); | |
+ printf("Pushed first task to queue\n"); | |
graph_task.has_any_work = true; | |
roots.push_back(graph_root.get()); | |
ready_queue(-1).push_front(FunctionTask(&graph_task, graph_root, InputBuffer(0))); | |
@@ -329,9 +359,15 @@ auto Engine::execute(const function_list& input_roots, | |
compute_dependencies(std::move(roots), graph_task); | |
// Wait for all tasks to complete | |
- graph_task.not_done.wait(lock, [&graph_task]{ | |
- return graph_task.outstanding_tasks.load() == 0; | |
- }); | |
+ tid(); | |
+ printf("Waiting for graph to complete!\n"); | |
+ if (graph_task.outstanding_tasks.load() != 0) { | |
+ graph_task.not_done.wait(lock, [&graph_task]{ | |
+ return graph_task.outstanding_tasks.load() == 0; | |
+ }); | |
+ } | |
+ tid(); | |
+ printf("Done waiting\n"); | |
// Check for an exception while running backwards | |
if (graph_task.has_error.load()) { | |
@@ -372,6 +408,8 @@ auto Engine::start_threads() -> void { | |
} | |
#endif | |
int num_threads = num_devices + 1; | |
+ tid(); | |
+ printf("Starting %d threads\n", num_threads); | |
ready_queues = std::vector<std::shared_ptr<ReadyQueue>>(num_threads); | |
for (int i = 0; i < num_threads; ++i) { | |
auto& queue = ready_queues[i]; | |
diff --git a/torch/csrc/autograd/engine.h b/torch/csrc/autograd/engine.h | |
index a0308f7d..66193b96 100644 | |
--- a/torch/csrc/autograd/engine.h | |
+++ b/torch/csrc/autograd/engine.h | |
@@ -55,7 +55,6 @@ protected: | |
virtual void thread_main(std::shared_ptr<ReadyQueue> queue, int device); | |
virtual void thread_on_exception(FunctionTask& task, std::exception& e); | |
- std::once_flag start_threads_flag; | |
std::vector<std::shared_ptr<ReadyQueue>> ready_queues; | |
std::vector<std::function<void()>> post_callbacks; | |
std::mutex post_callbacks_lock; |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.