Table of Contents

Results

Test 1

Test 2

Test 3

Model Information

Project

Code

Download


Inference with C# BERT NLP Deep Learning and ONNX Runtime

Results

Test 1

Context: Bob is walking through the woods collecting blueberries and strawberries to make a pie.

Question: What is his name?

(screenshot: inference result for Test 1)

Test 2

Context: Bob is walking through the woods collecting blueberries and strawberries to make a pie.

Question: What will he bring home?

(screenshot: inference result for Test 2)

Test 3

Context: Bob is walking through the woods collecting blueberries and strawberries to make a pie.

Question: Where is Bob?

(screenshot: inference result for Test 3)

Model Information

The model is bertsquad-10.onnx (a BERT model fine-tuned on SQuAD for extractive question answering). All sequence inputs use a fixed length of 256; -1 marks the dynamic batch dimension.

Inputs
-------------------------
name: unique_ids_raw_output___9:0
tensor: Int64[-1]
name: segment_ids:0
tensor: Int64[-1, 256]
name: input_mask:0
tensor: Int64[-1, 256]
name: input_ids:0
tensor: Int64[-1, 256]
---------------------------------------------------------------

Outputs
-------------------------
name: unstack:1
tensor: Float[-1, 256]
name: unstack:0
tensor: Float[-1, 256]
name: unique_ids:0
tensor: Int64[-1]
---------------------------------------------------------------
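
This listing can be reproduced in code. Below is a minimal sketch, assuming the Microsoft.ML.OnnxRuntime NuGet package and a local bertsquad-10.onnx file, that prints each input and output name with its element type and shape via InferenceSession.InputMetadata / OutputMetadata:

using System;
using Microsoft.ML.OnnxRuntime;

class InspectModel
{
    static void Main()
    {
        using var session = new InferenceSession("bertsquad-10.onnx");

        Console.WriteLine("Inputs");
        foreach (var (name, meta) in session.InputMetadata)
            // Dimensions reports -1 for dynamic axes (here, the batch size).
            // ElementType is the CLR type, so it prints e.g. "System.Int64".
            Console.WriteLine($"name:{name} tensor:{meta.ElementType}[{string.Join(", ", meta.Dimensions)}]");

        Console.WriteLine("Outputs");
        foreach (var (name, meta) in session.OutputMetadata)
            Console.WriteLine($"name:{name} tensor:{meta.ElementType}[{string.Join(", ", meta.Dimensions)}]");
    }
}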

Project

(screenshot: the WinForms project)

Code


using BERTTokenizers;
using Microsoft.ML.OnnxRuntime;
using System;
using System.Collections.Generic;
using System.Data;
using System.Diagnostics;
using System.Linq;
using System.Windows.Forms;

namespace Inference_with_C__BERT_NLP_Deep_Learning_and_ONNX_Runtime
{
    public struct BertInput
    {
        public long[] InputIds { get; set; }
        public long[] InputMask { get; set; }
        public long[] SegmentIds { get; set; }
        public long[] UniqueIds { get; set; }
    }

    public partial class Form1 : Form
    {
        public Form1()
        {
            InitializeComponent();
        }

        RunOptions runOptions;
        InferenceSession session;
        BertUncasedLargeTokenizer tokenizer;
        Stopwatch stopWatch = new Stopwatch();

        private void Form1_Load(object sender, EventArgs e)
        {
            string modelPath = "bertsquad-10.onnx";
            runOptions = new RunOptions();
            session = new InferenceSession(modelPath);
            tokenizer = new BertUncasedLargeTokenizer();
        }

        int MaxAnswerLength = 30;
        int bestN = 20;

        private void button1_Click(object sender, EventArgs e)
        {
            txt_answer.Text = "";
            Application.DoEvents();

            string question = txt_question.Text.Trim();
            string context = txt_context.Text.Trim();

            // Get the sentence tokens.
            var tokens = tokenizer.Tokenize(question, context);

            // Encode the sentence and pass in the count of the tokens in the sentence.
            var encoded = tokenizer.Encode(tokens.Count(), question, context);

            var padding = Enumerable
              .Repeat(0L, 256 - tokens.Count)
              .ToList();

            var bertInput = new BertInput()
            {
                InputIds = encoded.Select(t => t.InputIds).Concat(padding).ToArray(),
                InputMask = encoded.Select(t => t.AttentionMask).Concat(padding).ToArray(),
                SegmentIds = encoded.Select(t => t.TokenTypeIds).Concat(padding).ToArray(),
                UniqueIds = new long[] { 0 }
            };

            // Create input tensors over the input data.
            var inputIdsOrtValue = OrtValue.CreateTensorValueFromMemory(bertInput.InputIds,
                  new long[] { 1, bertInput.InputIds.Length });

            var inputMaskOrtValue = OrtValue.CreateTensorValueFromMemory(bertInput.InputMask,
                  new long[] { 1, bertInput.InputMask.Length });

            var segmentIdsOrtValue = OrtValue.CreateTensorValueFromMemory(bertInput.SegmentIds,
                  new long[] { 1, bertInput.SegmentIds.Length });

            var uniqueIdsOrtValue = OrtValue.CreateTensorValueFromMemory(bertInput.UniqueIds,
                  new long[] { bertInput.UniqueIds.Length });

            var inputs = new Dictionary<string, OrtValue>
              {
                  { "unique_ids_raw_output___9:0", uniqueIdsOrtValue },
                  { "segment_ids:0", segmentIdsOrtValue},
                  { "input_mask:0", inputMaskOrtValue },
                  { "input_ids:0", inputIdsOrtValue }
              };

            stopWatch.Restart();
            // Run session and send the input data in to get inference output. 
            var output = session.Run(runOptions, inputs, session.OutputNames);
            stopWatch.Stop();

            // Outputs come back in session.OutputNames order (see the model
            // info above): output[1] = "unstack:0" = start logits,
            // output[0] = "unstack:1" = end logits.
            var startLogits = output[1].GetTensorDataAsSpan<float>();
            var endLogits = output[0].GetTensorDataAsSpan<float>();
            var uniqueIds = output[2].GetTensorDataAsSpan<long>();

            // The first [SEP] ends the question; valid answers must start after it.
            var contextStart = tokens.FindIndex(o => o.Token == "[SEP]");

            var bestStartLogits = startLogits.ToArray()
                .Select((logit, index) => (Logit: logit, Index: index))
                .OrderByDescending(o => o.Logit)
                .Take(bestN);

            var bestEndLogits = endLogits.ToArray()
                .Select((logit, index) => (Logit: logit, Index: index))
                .OrderByDescending(o => o.Logit)
                .Take(bestN);

            var bestResultsWithScore = bestStartLogits
                .SelectMany(startLogit =>
                    bestEndLogits
                    .Select(endLogit =>
                        (
                            StartLogit: startLogit.Index,
                            EndLogit: endLogit.Index,
                            Score: startLogit.Logit + endLogit.Logit
                        )
                     )
                )
                // Discard invalid spans: end before start, answers longer than
                // MaxAnswerLength, the (0, 0) "no answer" span, and spans that
                // begin inside the question.
                .Where(entry => !(entry.EndLogit < entry.StartLogit || entry.EndLogit - entry.StartLogit > MaxAnswerLength || entry.StartLogit == 0 && entry.EndLogit == 0 || entry.StartLogit < contextStart))
                .Take(bestN);

            var (item, probability) = bestResultsWithScore
                .Softmax(o => o.Score)
                .OrderByDescending(o => o.Probability)
                .FirstOrDefault();

            int startIndex = item.StartLogit;
            int endIndex = item.EndLogit;

            var predictedTokens = tokens
                          .Skip(startIndex)
                          .Take(endIndex + 1 - startIndex)
                          .Select(o => tokenizer.IdToToken((int)o.VocabularyIndex))
                          .ToList();

            // Print the result.
            string answer = "answer:" + String.Join(" ", StitchSentenceBackTogether(predictedTokens))
                + "\r\nprobability:" + probability
                + $"\r\ninference time: {stopWatch.ElapsedMilliseconds} ms";

            txt_answer.Text = answer;
            Console.WriteLine(answer);

        }

        // Merge WordPiece subtokens (those prefixed with "##") back into whole words.
        private List<string> StitchSentenceBackTogether(List<string> tokens)
        {
            var currentToken = string.Empty;

            tokens.Reverse();

            var tokensStitched = new List<string>();

            foreach (var token in tokens)
            {
                if (!token.StartsWith("##"))
                {
                    currentToken = token + currentToken;
                    tokensStitched.Add(currentToken);
                    currentToken = string.Empty;
                }
                else
                {
                    currentToken = token.Replace("##", "") + currentToken;
                }
            }

            tokensStitched.Reverse();

            return tokensStitched;
        }
    }
}
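
Note that the listing calls a Softmax extension method on the candidate spans that is not shown above. Below is a minimal sketch of a compatible helper (an assumption about the original project, not its verbatim code): it runs a numerically stable softmax over each element's score and yields (Item, Probability) tuples, matching both the deconstruction var (item, probability) = ... and the OrderByDescending(o => o.Probability) call.

using System;
using System.Collections.Generic;
using System.Linq;

public static class SoftmaxExtensions
{
    // Numerically stable softmax: subtract the max score before
    // exponentiating so large logits cannot overflow.
    public static IEnumerable<(T Item, float Probability)> Softmax<T>(
        this IEnumerable<T> collection, Func<T, float> scoreSelector)
    {
        var items = collection.ToList();
        if (items.Count == 0)
            return Enumerable.Empty<(T Item, float Probability)>();

        var maxScore = items.Max(scoreSelector);
        var exp = items
            .Select(i => (Item: i, Exp: MathF.Exp(scoreSelector(i) - maxScore)))
            .ToList();
        var sum = exp.Sum(e => e.Exp);

        return exp.Select(e => (e.Item, Probability: e.Exp / sum));
    }
}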

Download

Source code download