Diese Präsentation wurde erfolgreich gemeldet.
Wir verwenden Ihre LinkedIn Profilangaben und Informationen zu Ihren Aktivitäten, um Anzeigen zu personalisieren und Ihnen relevantere Inhalte anzuzeigen. Sie können Ihre Anzeigeneinstellungen jederzeit ändern.

TensorFlow XLA 「XLAとは、から、最近の利用事例について」

2.497 Aufrufe

Veröffentlicht am

2019年2月2日にGoogleにて開催された
「fpgax #11 + TFUG ハード部:DNN専用ハードについて語る会」での発表資料です。
発表概要:
Googleが開発を行っているTensorFlow XLAについて、いったいどんなものかについて解説し、最近の利用事例、Julia Computing、PyTorch+XLAの中でどのような形で利用されているかを紹介します。

Veröffentlicht in: Geräte & Hardware
  • Als Erste(r) kommentieren

TensorFlow XLA 「XLAとは、から、最近の利用事例について」

  1. 1. TensorFlow XLA 「XLAとは、から、最近の利用事例について」 fpgax #11+TensorFlow User Groupハード部 「ディープラーニング専用ハードウェアについてわいわい話す会」 @ Google 作成:2019/12/29, 2019/1/6, 14, 20, 27 Slideshareにて公開 :2019/2/2 @Vengineer
  2. 2. ブログ (2007年~) : Vengineerの戯言  http://blogs.yahoo.co.jp/verification_engineer SlideShare :  https://www.slideshare.net/ssuser479fa3 Twitter (2009年~) : @Vengineer ソースコード解析職人
  3. 3. CQ出版社:雑誌インターフェース に 「TensorFlow XLA および Lite」 に関することを寄稿しました 2017年8月号 2017年9月号 2018年2月号 2018年8月号 2019年1月号 XLA AOT XLA AOT XLA JIT Lite & XLA Lite
  4. 4. 今日お話する内容    TensorFlow r1.13ベースのお話 ・XLA とは? ・XRT とは? XLAの最近の利用事例 ・JuliaでTPUを使う ・PyTorchでTPUを使う コースコード解析ベースなので、コード多いです
  5. 5. XLAとは
  6. 6. TensorFlow XLAとは https://www.tensorflow.org/performance/xla/ XLA(Accelerated Linear Algebra)は、TensorFlow計算を最適化する線形代数のドメ イン固有のコンパイラです。 結果として、サーバーおよびモバイルプラットフォーム での速度、メモリ使用率、移植性が向上します。 当初、ほとんどのユーザーはXLA の大きなメリットは見られませんが、JIT(Just-In-Time)コンパイルや AOT(Ahead-Of-Time)コンパイルを使用してXLAを使用することで実験を開始でき ます。 新しいハードウェアアクセラレータをターゲットとする開発者は、XLAを試すこ とを特にお勧めします。 原文(英語)をそのまま、Google翻訳にお願いしました。
  7. 7. TensorFlow XLAのソースコード r1.0 ~ r1.11 と r1.12 ~ では、違います Slideshareにアップしてある TensroFlow XLA : JIT編 (r1.3版) の内容は古いです
  8. 8. サンプルコードを見てみよう def test_xla_gpu(self): with tf.Session() as sess: x = tf.placeholder(tf.float32, [2], name="x") with tf.device("device:XLA_GPU:0"): y = x * 2 result = sess.run(y, {x: [1.5, 0.5]})
  9. 9. Mul Const Feed(x) Fetch(y) 0)、最初
  10. 10. Mul _Recv Const _Send Feed(x) Fetch(y) 1)、Feed/Fetchノードの追加
  11. 11. Mul _Recv Const _Send cpu : Feed(x) cpu : Fetch(y) XLA_GPU XLA_GPU 2)、Placement
  12. 12. 3)、グラフの分割 _Recv _Recv _Send _Send _Recv _Send XLA_GPU Feed(x) Fetch(y)cpu Mul Const
  13. 13. 4)、XlaLauch Opに変換 XlaLaunch _Recv _Recv _Send _Send _Recv _Send XLA_GPU Feed(x) Fetch(y)cpu
  14. 14. 複数Opsを XlaLaunch Op に変換 XlaLaunch MulConst
  15. 15. TensorFlow XLA : JITでは! 同じデバイス内で実行される Subgraph単位 の ノードをギュギュッと1つにまとめて、 XlaLaunch Op 内で実行する XlaLaunchは、 TensorFlow XLA専用のOpとして実装されている
  16. 16. Passを使ってグラフを変形してるよ compiler/jit/jit_compilation_pass_registration.cc REGISTER_OPTIMIZATIONマクロを使って、 OptimizationPassRegistry::POST_REWRITE_FOR_EXEC Passを追加  ・MarkForCompilationPass // コンパイル可能なものにマーク  ・EncapsulateSubgraphsPass // サブグラフを関数ノード  ・BuildXlaLaunchOpsPass // 関数ノードを_XlaLaunchに置換 上から順番に実行される
  17. 17. xla.compile
  18. 18. ポイント1 : xla.compile (Python API) compiler/xla/g3doc/tutorials/xla_compile.ipynb def build_mnist_model(x, y_): y = tf.keras.layers.Dense(NUM_CLASSES).apply(x) cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=y_, logits=y) train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) return y, train_step [y] = xla.compile(build_mnist_model, inputs=[images, labels])
  19. 19. Passを使ってグラフを変形してるよ compiler/jit/jit_compilation_pass_registration.cc REGISTER_OPTIMIZATIONマクロを使って、 OptimizationPassRegistry::PRE_PLACEMENT Passを追加  // EncapsulateXlaComputationsPass rewrites computations generated by the // xla.compile() Python code into XlaLaunch nodes. REGISTER_OPTIMIZATION( OptimizationPassRegistry::PRE_PLACEMENT, 26, EncapsulateXlaComputationsPass); // r1.12 にて導入 // Pythonコードで xla.compile() を実行するとここが呼ばれる
  20. 20. グラフから XlaLauch Op に変換 compiler/jit/encapsulate_xla_computations_pass.cc Status EncapsulateXlaComputationsPass::Run( const GraphOptimizationPassOptions& options) { Encapsulate(options.graph, options.flib_def); BuildXlaLaunchOps(options.graph->get()); return Status::OK(); }
  21. 21. XlaLaunch => XlaCompile + XlaRun - XlaCompile TF function を LocalExecutable にコンパイルする - XlaRun XlaCompile でコンパイルした LocalExecutable を実行する Split XlaLaunch into XlaCompile and XlaRun; NFC , 21 Sep 2018 TensorFlow r1.12にて導入された
  22. 22. Passを使ってグラフを変形してるよ compiler/jit/jit_compilation_pass_registration.cc REGISTER_OPTIMIZATIONマクロを使って、 OptimizationPassRegistry::POST_REWRITE_FOR_EXEC Passを追加  ・MarkForCompilationPass  ・IncreaseDynamismForAutoJitPass (r1.12にて導入)  ・PartiallyDeclusterPass (r1.11にて導入)  ・EncapsulateSubgraphsPass // サブグラフを関数ノード  ・BuildXlaOpsPass // Xla Ops をコンパイル置換
  23. 23. BuildXlaOpsPass::Run compiler/jit/build_xla_ops_pass.cc tatus BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) { Graph* graph = options.graph->get(); …. for (Node* n : xla_compiled_kernels) { // ここで、Xla Ops をコンパイルした後に、実行しています。 TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndXlaRun( *options.flib_def, lazy_compilation_enabled, graph, n)); } …
  24. 24. ReplaceNodeWithXlaCompileAndXlaRun compiler/jit/build_xla_ops_pass.cc XlaClusterInfo cluster_info; TF_RETURN_IF_ERROR(GetXlaClusterInfo(n, &cluster_info)); ops::_XlaCompile xla_compile(root.WithOpName("xla_compile"), /*constants=*/cluster_info.constant_inputs, /*args=*/cluster_info.non_constant_inputs, /*resources=*/cluster_info.resource_inputs, /*must_compile=*/requires_compilation, cluster_info.function);
  25. 25. REGISTER_OP("_XlaCompile") .Input("constants: Tconstants") .Attr("Tconstants: list(type) >= 0") .Attr("must_compile: bool") .Input("args: Targs") .Attr("Targs: list(type) >= 0") .Input("resources: Nresources * resource") .Attr("Nresources: int >= 0") .Output("key: string") // コンパイルがOKの時の key .Output("compilation_successful: bool") // コンパイルがOK/NG .Attr("function: func") // The compilation cache is stateful. .SetIsStateful() .Doc(R"(XLA Compile Op. For use by the XLA JIT only.
  26. 26. void XlaCompileOp::Compute(OpKernelContext* ctx) { xla::LocalClient* client; const XlaCompiler::CompilationResult* kernel; xla::LocalExecutable* executable; std::map<int, OptionalTensor> variables; … Status status = CompileToLocalExecutable ( ctx, function_, platform_info_, resources_, constants_, /*lazy=*/!must_compile_, &client, &variables, &kernel, &executable); if (must_compile_ || status.code() != error::UNIMPLEMENTED) { OP_REQUIRES_OK(ctx, status); }
  27. 27. // Each execution of an XlaCompile op creates a new XlaExecutableClosure, even // if it didn't have to compile the cluster because of a compilation-cache // hit. This is because we at least need new snapshots of the resource // variables. XlaExecutableClosureStore::KeyT key = XlaExecutableClosureStore::Global()->Produce(XlaExecutableClosure( client, executable, kernel, std::move(variables), constants_.size())); Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({})); compilation_key.flat<string>()(0) = key; Tensor compilation_successful(cpu_allocator, DT_BOOL, TensorShape({})); compilation_successful.flat<bool>()(0) = true; ctx->set_output(0, compilation_key) ; ctx->set_output(1, compilation_successful) ; }
  28. 28. static Status CompileToLocalExecutable ( … XlaCompilationCache* cache ; TF_RETURN_IF_ERROR(rm->LookupOrCreate<XlaCompilationCache>( rm->default_container(), "xla_cache", &cache, [&](XlaCompilationCache** cache) { return BuildCompilationCache(ctx, platform_info, cache); })); ... std::vector<XlaCompiler::Argument> args; TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments( constant_args, *variables, ctx, &args)); return cache->Compile(options, function, args, compile_options, lazy ? XlaCompilationCache::CompileMode::kLazy : XlaCompilationCache::CompileMode::kStrict, kernel, executable);
  29. 29. ReplaceNodeWithXlaCompileAndXlaRun compiler/jit/build_xla_ops_pass.cc std::vector<Output> xla_run_args = cluster_info.non_constant_inputs; absl::c_copy(cluster_info.resource_inputs, std::back_inserter(xla_run_args)); ops::_XlaRun xla_run(root.WithOpName("xla_run"), xla_run_args, xla_compile.key, n->output_types()); MoveOutgoingEdges(g, /*old_node=*/n, /*new_node=*/xla_run.operation.node());
  30. 30. REGISTER_OP("_XlaRun") .Input("args: Targs") // 引数 .Attr("Targs: list(type) >= 0") .Output("results: Tresults") .Attr("Tresults: list(type) >= 0") .Input("key: string") // _XlaCompile の結果の key // XLA random-number generation ops are stateful. // TODO(phawkins): create stateful and non-stateful variants of _XlaRun. .SetIsStateful() .Doc(R"(XLA Run Op. For use by the XLA JIT only.
  31. 31. void XlaRunOp::Compute(OpKernelContext* ctx) { Tensor key_tensor = ctx->input(ctx->num_inputs() - 1) ; const XlaExecutableClosureStore::KeyT& key = key_tensor.flat<string>()(0) ; XlaExecutableClosure closure = XlaExecutableClosureStore::Global()->Consume( key); … Env* env = Env::Default(); auto start_time = env->NowMicros(); auto run_result = closure.executable()->Run (launch_context.arguments(), run_options); OP_REQUIRES(ctx, run_result.ok(), run_result.status()); auto elapsed = env->NowMicros() - start_time; VLOG(2) << "Elapsed time in computation: " << elapsed << "us";
  32. 32. ポイント2:XlaCompile + XlaRun XlaCompile XlaRun XlaLaunch~ r1.11 r1.12にて 追加 LocalExecutable生成 実行
  33. 33. Eager Modeでは?
  34. 34. The XLA compile API # xla.compile() doesn't work with Keras model.fit() API or TF eager mode yet.
  35. 35. しかし、 ソースコードは語る
  36. 36. EagerLocalExecute // If we are running a function on explicitly requested TPU, // compile it with XLA. // Note that it is not ideal, but currently ok, to set this // attribute after computing the kernel cache key above. bool compile_with_xla = false; if (op->is_function() && device != nullptr && (device->device_type() == " TPU" || device->device_type() == "XLA_GPU" || device->device_type() == "XLA_CPU")) { op->MutableAttrs()->Set(kXlaCompileAttr, true); compile_with_xla = true; } kXlaCompileAttr を true にすると、 MarkForCompilationPass::Run にて、XLA化の準備をする
  37. 37. Passを使ってグラフを変形してるよ compiler/jit/jit_compilation_pass_registration.cc REGISTER_OPTIMIZATIONマクロを使って、 OptimizationPassRegistry::POST_REWRITE_FOR_EXEC Passを追加  ・MarkForCompilationPass  ・IncreaseDynamismForAutoJitPass (r1.12にて導入)  ・PartiallyDeclusterPass (r1.11にて導入)  ・EncapsulateSubgraphsPass // サブグラフを関数ノード  ・BuildXlaOpsPass // Xla Ops をコンパイル置換
  38. 38. XRTとは? r1.11で導入された
  39. 39. XRT TensorFlow以外からXLAを 利用するための仕組み?
  40. 40. XRTのテストコードを見てみよう xla::XlaComputation AddAndTuple () { xla::XlaBuilder builder("AddAndTuple"); auto p0 = xla::Parameter(&builder, 0, xla::ShapeUtil::MakeShape(xla::F32, {2}), "P0"); auto p1 = xla::Parameter(&builder, 1, xla::ShapeUtil::MakeShape(xla::F32, {2}), "P1"); auto sum = xla::Add(p0, p1); xla::Tuple(&builder, {sum}); return builder.Build().ValueOrDie(); } sum = xla::Add( p0, p1 )
  41. 41. TEST(RawApiTest, CompileAndExecuteReturnTuple ) { xrt::XLAAllocation p0; p0.set_device_ordinal(0); *p0.mutable_value() = FloatVector({1.0f, 2.0f}); xrt::XLAAllocation p1; p1.set_device_ordinal(0); *p1.mutable_value() = FloatVector({8.0f, 5.0f});
  42. 42. xrt::XLAComputation c; auto config = c.mutable_config(); auto shapes = config->mutable_program_shape(); *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto(); *shapes->mutable_result() = xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {2})}) .ToProto(); StoreComputationSnapshot( AddAndTuple(), c.mutable_hlo_snapshot()); xrt::XRTExecutionConfig e; e.set_release_input_handles(true); e.set_release_compilation_handle(true);
  43. 43. Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); auto e_config = ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString()); auto computation = ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString()); auto c_handle = ops::XRTCompile(root, computation); auto p0_value = ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString()); auto p0_handle = ops::XRTAllocate(root, p0_value); auto p1_value = ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString()); auto p1_handle = ops::XRTAllocate(root, p1_value); auto result = ops::XRTExecute(root, c_handle, e_config, {Output(p0_handle), Output(p1_handle)}); auto read_back = ops::XRTReadLiteralAndRelease (root, result); TF_ASSERT_OK(root.status());
  44. 44. ClientSession session(root); std::vector<Tensor> outputs; TF_EXPECT_OK(session.Run({ read_back}, &outputs)); xla::LiteralProto response; EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()())); auto sum = xla::LiteralUtil::CreateR1<float>({9.0f, 7.0f}); auto expected = xla::LiteralUtil::MakeTuple({&sum}); EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); }
  45. 45. XRTCompile XRTExecute XRTReadLiteralAndRelease xrt::XLAComputation xrt::XRTExecutionConfig xrt::XLAAllocation (入力データ) read_back XRTAllocate XRTAllocateConst Const Const Const
  46. 46. REGISTER_OP("XRTAllocate")  ・入力 allocation: string  ・出力 handle: int64
  47. 47. void XRTAllocate::Compute(OpKernelContext* ctx) override { const Tensor& allocation_info = ctx->input(0); // 入力:0 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(allocation_info.shape()), errors::Internal("allocation input should be a string scalar")); xrt::XLAAllocation allocation_proto; OP_REQUIRES( ctx, allocation_proto.ParseFromString(allocation_info.scalar<string>()()), errors::InvalidArgument( "Unable to parse allocation input to XLAAllocation")); xla::Literal literal; OP_REQUIRES_OK( ctx, XRTStateHelpers::MakeLiteral(allocation_proto.value(), &literal));
  48. 48. .... XRTTupleAllocation* allocation; OP_REQUIRES_OK(ctx, XRTTupleAllocation::CreateAndTransfer( literal, device_ref.backend(), device_ref.device_ordinal(), &allocation)); // Intern takes ownership of our reference to allocation. int64 key; OP_REQUIRES_OK(ctx, allocation->Intern(rm, &key)); Tensor output(DT_INT64, TensorShape({})); output.scalar<int64>()() = key; ctx->set_output(0, output); // 出力:0
  49. 49. REGISTER_OP("XRTCompile")  ・入力 computation: string  ・出力 handle: int64 program_shape: string
  50. 50. void XRTCompileOp::Compute(OpKernelContext* ctx) { .... const Tensor& computation_input = ctx->input(0); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(computation_input.shape()), errors::Internal("computation input should be a string scalar")); xrt::XLAComputation computation_proto ; OP_REQUIRES( ctx, computation_proto.ParseFromString(computation_input.scalar<string>()() ), errors::InvalidArgument( "Unable to parse computation input to XLAComputation"));
  51. 51. .... int64 uid; OP_REQUIRES_OK( ctx, cache->CompileIfKeyAbsent( key, &uid, [&](std::unique_ptr<xla::LocalExecutable>* program) { VLOG(1) << "Compiling XLA executable"; return Compile(ctx, computation_proto, program); })); std::unique_ptr<XRTCompilationCacheEntryRef> entry; OP_REQUIRES_OK(ctx, cache->Lookup(uid, &entry));
  52. 52. status XRTCompileOp::Compile(OpKernelContext* ctx, const xrt::XLAComputation& computation_proto , std::unique_ptr<xla::LocalExecutable>* program) { .... TF_ASSIGN_OR_RETURN( xla::XlaComputation computation , client->LoadSnapshot( computation_proto.hlo_snapshot ())); .... auto compile_result = client->Compile(computation, argument_layout_ptrs, build_options); if (!compile_result.ok()) { return compile_result.status(); } *program = std::move(compile_result.ValueOrDie()); return Status::OK(); }
  53. 53. .... Tensor handle_output(DT_INT64, TensorShape({})); handle_output.scalar<int64>()() = uid; ctx->set_output(0, handle_output); // 出力:0 xla::LocalExecutable* executable = entry->get().get_executable(); xla::ProgramShapeProto program_shape = executable->executable() ->module() .config() .entry_computation_layout() .ComputeProgramShape() .ToProto(); Tensor program_shape_output(DT_STRING, TensorShape({1})); program_shape_output.vec<string>()(0) = program_shape.SerializeAsString(); ctx->set_output(1, program_shape_output); // 出力:1 }
  54. 54. REGISTER_OP("XRTExecute") .Attr("Ninputs: int >= 0") .Input("computation_handle: int64") .Input("execution_config: string") .Input("input_handles: Ninputs * int64") .Output("output_handle: int64")  ・属性 Ninputs: int >= 0  ・入力 computation_handle: int64 execution_config: string input_handles: Ninputs * int64  ・出力 output_handle: int64
  55. 55. void XRTExecuteOp::ComputeAsync(OpKernelContext* context, DoneCallback done) { // Schedule onto the default queue, for unbounded concurrency. See b/73520706 Env::Default()->SchedClosure([this, context, done]() { OP_REQUIRES_OK_ASYNC(context, DoWork(context), done); done(); }); }
  56. 56. Status XRTExecuteOp::DoWork (OpKernelContext* context) { … const Tensor& execution_input = context->input(0); // 入力:0 TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_input.shape())); int64 compilation_handle = execution_input.scalar<int64>()(); const Tensor& execution_config = context->input(1); // 入力:1 TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_config.shape())); xrt::XRTExecutionConfig config_proto; TF_RET_CHECK( config_proto.ParseFromString(execution_config.scalar<string>()()));
  57. 57. … std::vector<xla::ShapedBuffer> input_allocations; std::vector<xla::ShapedBuffer*> input_pointers; TF_RETURN_IF_ERROR( GetComputationInputs (context, rm, release_inputs, &input_tuples, &input_allocations, &input_pointers)); … Status GetComputationInputs (OpKernelContext* context, ResourceMgr* rm, bool release_inputs, std::vector<XRTTupleAllocation*>* input_tuples, std::vector<xla::ShapedBuffer>* input_allocations, std::vector<xla::ShapedBuffer*>* input_pointers) { std::vector<int64> input_uids ; OpInputList arg_list; TF_RETURN_IF_ERROR( context->input_list("input_handles", &arg_list ));
  58. 58. Env* env = Env::Default(); auto start_time = env->NowMicros(); xla::LocalExecutable* executable = entry->get().get_executable(); auto run_result = executable->Run(input_pointers, run_options) ; if (!run_result.ok()) { return run_result.status(); } auto elapsed = env->NowMicros() - start_time; VLOG(2) << "Elapsed time: " << elapsed << "us";
  59. 59. Tensor* output_tensor; TF_RETURN_IF_ERROR( context->allocate_output(0, TensorShape({}), &output_tensor) ); int64 key; TF_RETURN_IF_ERROR(output_tuple->Intern(rm, &key)); output_tensor->scalar<int64>()() = key ;
  60. 60. REGISTER_OP("XRTReadLiteralAndRelease")  ・入力 handle: int64  ・出力 literal: string
  61. 61. void XRTReadLiteralOp::Compute(OpKernelContext* ctx) override { const Tensor& allocation_handle = ctx->input(0) ; // 入力:0 OP_REQUIRES( ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()), errors::Internal("computation input should be an int64 scalar")); int64 allocation_handle = handle_tensor.scalar<int64>()(); … xla::Literal literal; OP_REQUIRES_OK( ctx, allocation->ToLiteral(device_ref.backend(), device_ref.device_ordinal(), &literal)); xla::LiteralProto literal_proto = literal.ToProto(); Tensor output(DT_STRING, TensorShape({})); literal_proto.SerializeToString(&output.scalar<string>()()); ctx->set_output(0, output) ; // 出力:0 }
  62. 62. ポイント3:XRTCompile + XRTExecute XRTReadLiteralA ndRelease XRTAllocate XRTExecute XRTCompile
  63. 63. JuliaでTPUを使う
  64. 64. Automatic Full Compilation of Julia Programs and ML Models to Cloud TPUs https://arxiv.org/abs/1810.09868 Qiita : XLA.jl を試してみた Qiita : JuliaからCloud TPUを使う論文の、ざっくりまとめ
  65. 65. 引用:https://kiszk.github.io/2018/12/19/TensorFlow-Julia-TPU-XLA/
  66. 66. PyTorchでTPUを使う
  67. 67. Introducing PyTorch across Google Cloud , 2018.10.3 https://cloud.google.com/blog/products/ai-machine-learning/introducing-p ytorch-across-google-cloud Today, we’re pleased to announce that engineers on Google’s TPU team are actively collaborating with core PyTorch developers to connect PyTorch to Cloud TPUs. The long-term goal is to enable everyone to enjoy the simplicity and flexibility of PyTorch while benefiting from the performance, scalability, and cost-efficiency of Cloud TPUs.
  68. 68. As a starting point, the engineers involved have produced a prototype that connects PyTorch to Cloud TPUs via XLA, an open source linear algebra compiler. This prototype has successfully enabled us to train a PyTorch implementation of ResNet-50 on a Cloud TPU, and we’re planning to open source the prototype and then expand it in collaboration with the PyTorch community. Please email us at pytorch-tpu@googlegroups.com to tell us what types of PyTorch workloads you would be most interested in accelerating with Cloud TPUs!
  69. 69. PyTorch For TPU : 2018.11.13公開 https://github.com/pytorch/xla PyTorch + XLA のソースコードが公開された:2018.11.16のブログ 利用ケース:  1)、CPUのXLAにて実行する場合。  2)、XRT経由でXLAを利用して、CPUのXLAにて実行する場合、  3)、XRT経由でXLAを利用して、TPUのXLAにて実行する場合 最新コードでは、2) と 3) のみ
  70. 70. テストコード:test/test_train_imagenet.py … import torch_xla_py.xla_model as xm # xla_model をインポート … def train_imagenet(): … model = torchvision.models.resnet50 () # モデル (resnet50) cross_entropy_loss = nn.CrossEntropyLoss() devices = [':{}'.format(n) for n in range(0, FLAGS.num_cores)] inputs = torch.zeros(FLAGS.batch_size, 3, 224, 224) target = torch.zeros(FLAGS.batch_size, dtype=torch.int64) xla_model = xm.XlaModel( # XlaModel にてモデルコンパイル model, [inputs], loss_fn=cross_entropy_loss, target=target, num_cores=FLAGS.num_cores, devices=devices)
  71. 71. … optimizer = optim.SGD( xla_model.parameters_list(), lr=lr, momentum=momentum, weight_decay=5e-4) log_fn = xm.get_log_fn(logdir=FLAGS.logdir) for epoch in range(1, FLAGS.num_epochs + 1): xla_model.train( # xla_model.train にて学習 train_loader, optimizer, FLAGS.batch_size, log_interval=log_interval, metrics_debug=FLAGS.metrics_debug, log_fn=log_fn)
  72. 72. accuracy = xla_model.test( # xla_mode.test にて推論 test_loader, _cross_entropy_loss_eval_fn(cross_entropy_loss), FLAGS.batch_size, log_fn=log_fn) xm.update_optimizer_state(optimizer, 'lr', lambda x: x / 1.025) return accuracy
  73. 73. PyTorch で TPU を使うときのポイント 1)、xla_model をインポート import torch_xla_py.xla_model as xm 2)、XlaModel にてモデルコンパイル xla_model = xm.XlaModel( model, [inputs], … 3)、train にて学習 for epoch in range(1, FLAGS.num_epochs + 1): xla_model.train(...) 4)、test にて推論 accuracy = xla_model.test(...)
  74. 74. def train(self, samples_loader, optimizer, batch_size, log_interval=1, log_fn=print, metrics_debug=False): wloader = LoaderWrapper( samples_loader, self._loader_prefetch, batch_size, num_cores=self._num_cores, devices=self._devices, fused_mode=True) wloader_cleaner = xu.Cleaner(wloader.close) loss = None start_time = time.time() self._epoch += 1
  75. 75. for batch_number, (inputs, targets) in wloader: self._step += 1 optimizer.zero_grad() xla_outputs = xla_run_model(self._xla_model, inputs, devices=self._devices) xla_run_grad(self._xla_model, self._get_backward_grads(xla_outputs), devices=self._devices) optimizer.step() if (log_fn is not None and log_interval is not None and batch_number % log_interval == 0): if metrics_debug: log_fn(torch_xla._XLAC._xla_metrics_report()) loss = self._compute_loss(xla_outputs) log_fn( TrainStepMetrics(self._epoch, self._num_cores, batch_number, len(samples_loader), batch_size, loss, time.time() - start_time, self._step)) return loss
  76. 76. def test(self, samples_loader, eval_fn, batch_size, log_fn=print): wloader = LoaderWrapper( samples_loader, self._loader_prefetch, batch_size, num_cores=self._num_cores, devices=self._devices, fused_mode=True) wloader_cleaner = xu.Cleaner(wloader.close) test_loss = 0 count = 0 correct = 0 start_time = time.time()
  77. 77. with torch.no_grad(): for batch_number, (inputs, targets) in wloader: xla_outputs = xla_run_model(self._xla_model, inputs, devices=self._devices) for i, replica_xla_outputs in enumerate(xla_outputs): output = replica_xla_outputs[1].to_tensor() closs, ccorrect = eval_fn(output, inputs[i][1].to_tensor()) test_loss += closs correct += ccorrect count += batch_size test_loss /= count accuracy = 100.0 * correct / count if log_fn is not None: log_fn( TestStepMetrics(test_loss, correct, count, time.time() - start_time, self._step)) return accuracy
  78. 78. # Run an XLA model with the given tensors. def xla_run_model(xla_model, inputs, devices=None): return xla_model(*convert_to_xla_tensors(inputs, devices=devices)) …  C++側のコード (Pybind11を利用) .def("__call__", [](XlaModule& xla_module , py::args args) -> py::object { auto inputs = XlaCreateTensorList(args); XlaModule::TensorBatchVector outputs; { NoGilSection nogil; outputs = xla_module.forward(inputs); } return XlaPackTensorList(outputs); })
  79. 79. # Runs the backward pass for the given XLA model and the gradient outputs. def xla_run_grad(xla_model, grad_outputs, devices=None): # Trace and symbolically differentiate grads_output_xla = convert_to_xla_tensors(grad_outputs, devices=devices) xla_model.backward (*grads_output_xla) …  C++側のコード (Pybind11を利用) .def("backward", [](XlaModule& xla_module , py::args args) { auto inputs = XlaCreateTensorList(args); NoGilSection nogil; xla_module.backward (inputs); })
  80. 80. torch_xla/csrc/translator.cpp この部分は、2019.1.20に追記 at::aten add、div、sub、mul、gt、type_as、convolution、 thnn_conv2d_forward、thnn_conv2d_backward、t、addmm、mm、max_pool2d_with_indices、 max_pool2d_with_indices_backward、 avg_pool2d、avg_pool2d_backward、 adaptive_avg_pool2d、adaptive_avg_pool2d_backward、 sqrt、rsqrt、neg、tanh、sigmoid、relu、threshold、threshold_backward、 log_softmax、_log_softmax_backward_data、 reshape、view、expand、stack、cat、chunk、 native_batch_norm、batch_norm、native_batch_norm_backward、 sum、nll_loss、nll_loss_backward、size at::prim Constant、Undefined、SumToSize、ListConstruct
  81. 81. 今日のまとめ ・XLA は、r1.12で変わった   ポイント1 : xla.compile (Python API)   ポイント2 : XlaCompile + XlaRun ・XRT が、r1.11 にて導入され ・Julia や PyTorch でTPUで使える   ポイント3 : XRTCompile + XRTExecute
  82. 82. おまけ
  83. 83. まだあります、XLA を利用したもの ・Google/jax : JAX: Autograd and XLA https://github.com/google/jax 現在、頻繁に更新されています!(‘xla’ or ‘xrt’, TPUはまだの模様) ブログ:今週の月曜日(2019.1.28)から金曜日(2019.2.1)まで ・LeFlow : XLA => FPGA https://github.com/danielholanda/LeFlow XLA => LLVM => (LegUp) => Verilog HDL (TensorFlow r1.6ベース、リリース後更新無し?)
  84. 84. あたしは、 ディープラーニング職人 ではありません コンピュータエンジニア です ありがとうございました @Vengineer ソースコード解析職人

×