كيفية استدعاء نموذج TensorFlow المدربين من برامج Java

اللغة الأساسية التي يتم فيها إنشاء وتدريب نماذج TensorFlow للتعلم الآلي هي Python. ومع ذلك ، تتم كتابة العديد من برامج خادم المؤسسات بلغة Java. لذا ، غالبًا ما تصادف مواقف تحتاج فيها إلى استدعاء نموذج Tensorflow الذي قمت بتدريبه في Python من برنامج Java.

إذا كنت تستخدم Cloud ML Engine على Google Cloud Platform ، فهذه ليست مشكلة - في Cloud ML Engine ، يتم إجراء التنبؤات من خلال مكالمة REST API وبالتالي يمكنك القيام بذلك من أي لغة برمجة. ولكن ماذا لو قمت بتنزيل نموذج TensorFlow ، وترغب في تنفيذ التوقعات دون اتصال بالإنترنت؟

إليك كيفية إجراء تنبؤات في Java باستخدام نماذج Tensorflow التي تم تدريبها في Python.

ملاحظة: بدأ فريق Tensorflow الآن في إضافة روابط Java. انظر https://github.com/tensorflow/tensorflow/tree/master/tensorflow/java للحصول على التفاصيل. جرب ذلك أولاً ، وإذا لم ينجح ذلك ، تعال إلى هنا ...

اكتب ملفات النماذج في Python

أول شيء يجب فعله هو حفظ نموذج TensorFlow في Python بتنسيقين: (أ) الأوزان والتحيزات وما إلى ذلك كملف "saver_def" (ب) الرسم البياني نفسه كملف protobuf. للحفاظ على سلامة عقلك ، قد ترغب في حفظ الرسم البياني كنص وتنسيق protobuf ثنائي. ستجد أنه من المفيد قراءة تنسيق النص للعثور على الأسماء التي تم تعيينها بواسطة TensorFlow إلى العقد التي لم تقم بتعيين أسماء لها صراحة. رمز كتابة هذه الملفات الثلاثة من Python:

# إنشاء كائن Saver كالمعتاد في Python لحفظ المتغيرات الخاصة بك
التوقف = tf.train.Saver (...)
# استخدم saver_def للحصول على السلاسل "السحرية" لاستعادتها
saver_def = saver.as_saver_def ()
طباعة saver_def.filename_tensor_name
طباعة saver_def.restore_op_name
# اكتب 3 ملفات
saver.save (sess ، "المدربين". sd ')
tf.train.write_graph (sess.graph_def، '.'، 'المدربين_مودل.بروتو'، as_text = خطأ)
tf.train.write_graph (sess.graph_def، '.'، 'المدربين_الموديل txt'، as_text = صواب)

في حالتي ، كانت السلاسل السحرية المطبوعة من save_def هي save / Const: 0 و save / rest_all - وهذا ما ستراه في كود Java الخاص بي. قم بتغييرها عندما تكتب كود Java الخاص بك إذا كانت مختلفة.

يحتوي ملف .sd على أوزان وتحيزات وما إلى ذلك (القيم الفعلية للمتغيرات في الرسم البياني الخاص بك). ملف .proto هو ملف ثنائي يحتوي على الرسم البياني لحسابك و .txt إصدار النص المقابل.

استدعاء Tensorflow C ++ من جافا

على الرغم من أنك قد تستخدم Tensorflow في Python لتغذية البيانات إلى النموذج الخاص بك وتدريبه ، فإن حزمة Tensorflow Python تستدعي بالفعل تطبيق C ++ لتنفيذ العمل الفعلي. لذلك ، يمكننا استخدام Java Native Interface (JNI) لاستدعاء C ++ مباشرةً واستخدام C ++ لإنشاء الرسم البياني واستعادة الأوزان والتحيزات من النموذج من Java.

بدلاً من كتابة جميع مكالمات JNI باليد ، من الممكن استخدام مكتبة مفتوحة المصدر تسمى JavaCpp للقيام بذلك. لاستخدام JavaCpp ، أضف هذه التبعية إلى Java Maven pom.xml:

<الاعتماد>
   org.bytedeco.javacpp-المسبقة 
   tensorflow 
  <إصدار> 0.9.0-1.2 

إذا كنت تستخدم بعض أنظمة إدارة الإنشاءات الأخرى ، فأضف الإعدادات المسبقة لـ Javacpp لـ tensorflow وكل تبعياته إلى مسار التطبيق الخاص بك.

إنشاء نموذج في جافا

في شفرة Java الخاصة بك ، اقرأ ملف proto لإنشاء تعريف Graph كما يلي (تم حذف عمليات الاستيراد للتوضيح):

الجلسة النهائية = الجلسة الجديدة (الجلسة الجديدة ()) ؛
GraphDef def = new GraphDef ()؛
tensorflow.ReadBinaryProto (Env.Default ()
                           "somedir / المدربين_الموظفين" ، "def" ؛
الحالة s = session.Create (def) ؛
إذا (! s.ok ()) {
    رمي جديد RuntimeException (s.error_message (). getString ()) ؛
}

بعد ذلك ، قم باستعادة الأوزان والتحيزات من ملف النموذج المحفوظ باستخدام Session :: Run (). لاحظ كيفية استخدام السلاسل السحرية من saver_def.

// استعادة
Tensor fn = new Tensor (tensorflow.DT_STRING ، TensorShape جديد (1)) ؛
StringArray a = fn.createStringArray ()؛
a.position (0). وضع ( "somedir / trained_model.sd")؛
s = session.Run (new StringTensorPairVector (new String [] {"save / Const: 0"} ، Tensor الجديد [] {fn}) ، StringVector جديدة () ، StringVector جديدة ("حفظ / استعادة_all") ، TensorVector جديدة ( ))؛
إذا (! s.ok ()) {
   رمي جديد RuntimeException (s.error_message (). getString ()) ؛
}

عمل تنبؤات في جافا

في هذه المرحلة ، نموذجك جاهز. يمكنك الآن استخدامه لعمل تنبؤات. هذا مشابه لكيفية القيام بذلك في Python - يجب عليك تمرير قيم لجميع العناصر النائبة الخاصة بك وتقييم عقدة المخرجات. الفرق هو أنه يجب عليك معرفة الأسماء الفعلية لعنصر نائب وإخراج. إذا لم تقم بتعيين أسماء فريدة لهذه العقد في Python ، فقد قام Tensorflow بتعيين أسماء لها. يمكنك معرفة ما هي عليه من خلال النظر في ملف المدربين. txt الذي كتب مكتوب. أو يمكنك الرجوع إلى رمز Python وتعيين أسماء عقد المفاتيح التي تتذكرها. في حالتي ، تم تسمية العنصر النائب للإدخال "العنصر النائب" ؛ تم استدعاء العنصر النائب عقدة التسرب Placeholder_2 ، وكان يسمى عقدة الإخراج Sigmoid. سترى هذه المشار إليها في المكالمة Session :: Run () أدناه.

في حالتي ، تستخدم الشبكة العصبية 5 متغيرات تنبؤية. على افتراض أن لدي مجموعة من المدخلات التي تنبئ بنموذج الشبكة العصبية وترغب في القيام بالتنبؤ بمجموعتين من هذه المدخلات ، فإن مدخلاتي هي مصفوفة 2 × 5. يحتوي NN الخاص بي على خرج واحد فقط ، لذلك بالنسبة لمجموعتي المدخلات ، يكون الموتر الناتج مصفوفة 2 × 1. يتم إعطاء عقدة التسرب من إدخال مدمج من 1.0 (في التنبؤ ، نحافظ على جميع العقد - احتمال التسرب هو فقط للتدريب). لذلك أنا أملك:

/ / حاول التنبؤ بمجموعتين (2) من المدخلات.
مدخلات الموتر = الموتر الجديد (
         tensorflow.DT_FLOAT ، TensorShape الجديد (2،5)) ؛
FloatBuffer x = inputs.createBuffer ()؛
x.put (new float [] {- ​​6.0f، 22.0f، 383.0f، 27.781754111198122f، -6.5f})؛
x.put (new float [] {66.0f، 22.0f، 2422.0f، 45.72160947712418f، 0.4f})؛
Tensor keepall = جديد Tensor (
        tensorflow.DT_FLOAT ، TensorShape الجديد (2،1)) ؛
((FloatBuffer) keepall.createBuffer ()). put (float new [] {1f، 1f})؛
مخرجات TensorVector = جديدة TensorVector () ؛
// للتنبؤ في كل مرة ، وتمرير القيم للعناصر النائبة
outputs.resize (0)؛
s = session.Run (جديد StringTensorPairVector (سلسلة جديدة [] {"Placeholder" ، "Placeholder_2"} ، Tensor جديد [] {inputs ، keepall}) ،
 StringVector جديدة ("Sigmoid") ، StringVector () جديدة ، مخرجات) ؛
إذا (! s.ok ()) {
   رمي جديد RuntimeException (s.error_message (). getString ()) ؛
}
/ / هذه هي الطريقة التي تسترجع بها القيمة المتوقعة من المخرجات
إخراج FloatBuffer = outputs.get (0) .createBuffer ()؛
لـ (int k = 0 ؛ k 

هذا كل شيء - أنت الآن تستخدم Java لتنفيذ تنبؤاتك. هناك العديد من الخطوات ، لكن هذا متوقع عند خلط ثلاث لغات برمجة (Python و C ++ و Java). ولكن الشيء المهم هو أنه يمكن القيام به ، وأنه واضح ومباشر.

بالطبع ، لا يستفيد هذا من تسريع الأجهزة وتوزيعها. إذا كنت ترغب في عمل تنبؤات بمعدل مرتفع جدًا في الوقت الفعلي ، فيجب عليك استخدام Cloud ML Engine.