diff options
author | icebaker <icebaker@proton.me> | 2023-05-11 19:24:50 -0300 |
---|---|---|
committer | icebaker <icebaker@proton.me> | 2023-05-11 19:24:50 -0300 |
commit | ec5e25547a401141586c87621266f9cd68c59e3c (patch) | |
tree | 547b3c7fa04c9e695785b9beeda0be5a4a77b006 /components |
first commit
Diffstat (limited to 'components')
-rw-r--r-- | components/provider.rb | 20 | ||||
-rw-r--r-- | components/providers/base.rb | 15 | ||||
-rw-r--r-- | components/providers/openai.rb | 79 |
3 files changed, 114 insertions, 0 deletions
diff --git a/components/provider.rb b/components/provider.rb new file mode 100644 index 0000000..dbfc8bd --- /dev/null +++ b/components/provider.rb @@ -0,0 +1,20 @@ +# frozen_string_literal: true + +require 'openai' + +require_relative './providers/openai' + +module NanoBot + module Components + class Provider + def self.new(provider) + case provider[:name] + when 'openai' + Providers::OpenAI.new(provider[:settings]) + else + raise "Unsupported provider #{provider[:name]}" + end + end + end + end +end diff --git a/components/providers/base.rb b/components/providers/base.rb new file mode 100644 index 0000000..011c5dd --- /dev/null +++ b/components/providers/base.rb @@ -0,0 +1,15 @@ +# frozen_string_literal: true + +require 'openai' + +module NanoBot + module Components + module Providers + class Base + def evaluate(_payload) + raise NoMethodError, "The 'evaluate' method is not implemented for the current provider." + end + end + end + end +end diff --git a/components/providers/openai.rb b/components/providers/openai.rb new file mode 100644 index 0000000..e163573 --- /dev/null +++ b/components/providers/openai.rb @@ -0,0 +1,79 @@ +# frozen_string_literal: true + +require 'openai' + +require_relative './base' + +module NanoBot + module Components + module Providers + class OpenAI < Base + CHAT_SETTINGS = %i[ + model stream temperature top_p n stop max_tokens + presence_penalty frequency_penalty logit_bias + ].freeze + + attr_reader :settings + + def initialize(settings) + @settings = settings + + @client = ::OpenAI::Client.new( + uri_base: "#{@settings[:credentials][:address].sub(%r{/$}, '')}/", + access_token: @settings[:credentials][:'access-token'] + ) + end + + def evaluate(input, &block) + messages = input[:history].map do |event| + { role: event[:who] == 'user' ? 'user' : 'assistant', + content: event[:message] } + end + + %i[instruction backdrop directive].each do |key| + next unless input[:behavior][key] + + messages.prepend( + { role: key == :directive ? 'system' : 'user', + content: input[:behavior][key] } + ) + end + + payload = { + model: @settings[:model], + user: @settings[:credentials][:'user-identifier'], + messages: + } + + CHAT_SETTINGS.each do |key| + payload[key] = @settings[key] if @settings.key?(key) + end + + payload.delete(:logit_bias) if payload.key?(:logit_bias) && payload[:logit_bias].nil? + + if @settings[:stream] && input[:interface][:stream] + content = '' + + payload[:stream] = proc do |chunk, _bytesize| + partial = chunk.dig('choices', 0, 'delta', 'content') + if partial + content += partial + block.call({ who: 'AI', message: partial }, false) + end + + block.call({ who: 'AI', message: content }, true) if chunk.dig('choices', 0, 'finish_reason') + end + + @client.chat(parameters: payload) + else + result = @client.chat(parameters: payload) + + raise StandardError, result['error'] if result['error'] + + block.call({ who: 'AI', message: result.dig('choices', 0, 'message', 'content') }, true) + end + end + end + end + end +end |