diff options
author | icebaker <icebaker@proton.me> | 2023-12-15 08:04:27 -0300 |
---|---|---|
committer | icebaker <icebaker@proton.me> | 2023-12-15 08:04:27 -0300 |
commit | fef3d5b3b2f823999fae68276382fe33872350c4 (patch) | |
tree | f41758bcc6c9e7e5805269afe4104a636dc17baa /components | |
parent | b9a22a449d33d254f2c1a7f3d2196712ff6d9b8a (diff) |
improving provider options
Diffstat (limited to 'components')
-rw-r--r-- | components/provider.rb | 2 | ||||
-rw-r--r-- | components/providers/google.rb | 24 | ||||
-rw-r--r-- | components/providers/openai.rb | 2 |
3 files changed, 21 insertions, 7 deletions
diff --git a/components/provider.rb b/components/provider.rb index 2ad35f4..d83319f 100644 --- a/components/provider.rb +++ b/components/provider.rb @@ -11,7 +11,7 @@ module NanoBot when 'openai' Providers::OpenAI.new(provider[:settings], provider[:credentials], environment:) when 'google' - Providers::Google.new(provider[:model], provider[:settings], provider[:credentials], environment:) + Providers::Google.new(provider[:options], provider[:settings], provider[:credentials], environment:) else raise "Unsupported provider \"#{provider[:id]}\"" end diff --git a/components/providers/google.rb b/components/providers/google.rb index 2a99bcb..f847677 100644 --- a/components/providers/google.rb +++ b/components/providers/google.rb @@ -14,13 +14,16 @@ module NanoBot module Providers class Google < Base SETTINGS = { - safetySettings: %i[category threshold].freeze, - generationConfig: %i[temperature topP topK candidateCount maxOutputTokens stopSequences].freeze + generationConfig: %i[ + temperature topP topK candidateCount maxOutputTokens stopSequences + ].freeze }.freeze + SAFETY_SETTINGS = %i[category threshold].freeze + attr_reader :settings - def initialize(model, settings, credentials, _environment) + def initialize(options, settings, credentials, _environment) @settings = settings @client = Gemini.new( @@ -29,7 +32,7 @@ module NanoBot project_id: credentials[:'project-id'], region: credentials[:region] }, - settings: { model:, stream: false } + settings: { model: options[:model], stream: options[:stream] } ) end @@ -77,6 +80,16 @@ module NanoBot end end end + + if @settings[:safetySettings].is_a?(Array) + payload[:safetySettings] = [] unless payload.key?(:safetySettings) + + @settings[:safetySettings].each do |safety_setting| + setting = {} + SAFETY_SETTINGS.each { |key| setting[key] = safety_setting[key] } + payload[:safetySettings] << setting + end + end end if input[:tools] @@ -143,7 +156,8 @@ module NanoBot else begin result = @client.stream_generate_content( - Logic::Google::Tokens.apply_policies!(cartridge, payload) + Logic::Google::Tokens.apply_policies!(cartridge, payload), + stream: false ) rescue StandardError => e raise e.class, e.response[:body] if e.response && e.response[:body] diff --git a/components/providers/openai.rb b/components/providers/openai.rb index b70984b..f6eafd4 100644 --- a/components/providers/openai.rb +++ b/components/providers/openai.rb @@ -18,7 +18,7 @@ module NanoBot CHAT_SETTINGS = %i[ model stream temperature top_p n stop max_tokens - presence_penalty frequency_penalty logit_bias + presence_penalty frequency_penalty logit_bias seed response_format ].freeze attr_reader :settings |