summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authoricebaker <icebaker@proton.me>2023-12-15 08:04:27 -0300
committericebaker <icebaker@proton.me>2023-12-15 08:04:27 -0300
commitfef3d5b3b2f823999fae68276382fe33872350c4 (patch)
treef41758bcc6c9e7e5805269afe4104a636dc17baa
parentb9a22a449d33d254f2c1a7f3d2196712ff6d9b8a (diff)
improving provider options
-rw-r--r--components/provider.rb2
-rw-r--r--components/providers/google.rb24
-rw-r--r--components/providers/openai.rb2
-rw-r--r--logic/cartridge/streaming.rb9
-rw-r--r--spec/data/cartridges/streaming.yml1
-rw-r--r--spec/logic/cartridge/streaming_spec.rb19
-rw-r--r--static/cartridges/default.yml2
7 files changed, 47 insertions, 12 deletions
diff --git a/components/provider.rb b/components/provider.rb
index 2ad35f4..d83319f 100644
--- a/components/provider.rb
+++ b/components/provider.rb
@@ -11,7 +11,7 @@ module NanoBot
when 'openai'
Providers::OpenAI.new(provider[:settings], provider[:credentials], environment:)
when 'google'
- Providers::Google.new(provider[:model], provider[:settings], provider[:credentials], environment:)
+ Providers::Google.new(provider[:options], provider[:settings], provider[:credentials], environment:)
else
raise "Unsupported provider \"#{provider[:id]}\""
end
diff --git a/components/providers/google.rb b/components/providers/google.rb
index 2a99bcb..f847677 100644
--- a/components/providers/google.rb
+++ b/components/providers/google.rb
@@ -14,13 +14,16 @@ module NanoBot
module Providers
class Google < Base
SETTINGS = {
- safetySettings: %i[category threshold].freeze,
- generationConfig: %i[temperature topP topK candidateCount maxOutputTokens stopSequences].freeze
+ generationConfig: %i[
+ temperature topP topK candidateCount maxOutputTokens stopSequences
+ ].freeze
}.freeze
+ SAFETY_SETTINGS = %i[category threshold].freeze
+
attr_reader :settings
- def initialize(model, settings, credentials, _environment)
+ def initialize(options, settings, credentials, _environment)
@settings = settings
@client = Gemini.new(
@@ -29,7 +32,7 @@ module NanoBot
project_id: credentials[:'project-id'],
region: credentials[:region]
},
- settings: { model:, stream: false }
+ settings: { model: options[:model], stream: options[:stream] }
)
end
@@ -77,6 +80,16 @@ module NanoBot
end
end
end
+
+ if @settings[:safetySettings].is_a?(Array)
+ payload[:safetySettings] = [] unless payload.key?(:safetySettings)
+
+ @settings[:safetySettings].each do |safety_setting|
+ setting = {}
+ SAFETY_SETTINGS.each { |key| setting[key] = safety_setting[key] }
+ payload[:safetySettings] << setting
+ end
+ end
end
if input[:tools]
@@ -143,7 +156,8 @@ module NanoBot
else
begin
result = @client.stream_generate_content(
- Logic::Google::Tokens.apply_policies!(cartridge, payload)
+ Logic::Google::Tokens.apply_policies!(cartridge, payload),
+ stream: false
)
rescue StandardError => e
raise e.class, e.response[:body] if e.response && e.response[:body]
diff --git a/components/providers/openai.rb b/components/providers/openai.rb
index b70984b..f6eafd4 100644
--- a/components/providers/openai.rb
+++ b/components/providers/openai.rb
@@ -18,7 +18,7 @@ module NanoBot
CHAT_SETTINGS = %i[
model stream temperature top_p n stop max_tokens
- presence_penalty frequency_penalty logit_bias
+ presence_penalty frequency_penalty logit_bias seed response_format
].freeze
attr_reader :settings
diff --git a/logic/cartridge/streaming.rb b/logic/cartridge/streaming.rb
index a0f8700..6949b3a 100644
--- a/logic/cartridge/streaming.rb
+++ b/logic/cartridge/streaming.rb
@@ -7,7 +7,14 @@ module NanoBot
module Cartridge
module Streaming
def self.enabled?(cartridge, interface)
- return false if Helpers::Hash.fetch(cartridge, %i[provider settings stream]) == false
+ provider_stream = case Helpers::Hash.fetch(cartridge, %i[provider id])
+ when 'openai'
+ Helpers::Hash.fetch(cartridge, %i[provider settings stream])
+ when 'google'
+ Helpers::Hash.fetch(cartridge, %i[provider options stream])
+ end
+
+ return false if provider_stream == false
specific_interface = Helpers::Hash.fetch(cartridge, [:interfaces, interface, :output, :stream])
diff --git a/spec/data/cartridges/streaming.yml b/spec/data/cartridges/streaming.yml
index 8234d34..e004110 100644
--- a/spec/data/cartridges/streaming.yml
+++ b/spec/data/cartridges/streaming.yml
@@ -10,5 +10,6 @@ interfaces:
stream: true
provider:
+ id: openai
settings:
stream: true
diff --git a/spec/logic/cartridge/streaming_spec.rb b/spec/logic/cartridge/streaming_spec.rb
index 466dd0b..4b71dfd 100644
--- a/spec/logic/cartridge/streaming_spec.rb
+++ b/spec/logic/cartridge/streaming_spec.rb
@@ -7,11 +7,22 @@ require_relative '../../../logic/cartridge/streaming'
RSpec.describe NanoBot::Logic::Cartridge::Streaming do
context 'interfaces override' do
context 'defaults' do
- let(:cartridge) { {} }
+ context 'openai' do
+ let(:cartridge) { { provider: { id: 'openai' } } }
- it 'uses default values when appropriate' do
- expect(described_class.enabled?(cartridge, :repl)).to be(true)
- expect(described_class.enabled?(cartridge, :eval)).to be(true)
+ it 'uses default values when appropriate' do
+ expect(described_class.enabled?(cartridge, :repl)).to be(true)
+ expect(described_class.enabled?(cartridge, :eval)).to be(true)
+ end
+ end
+
+ context 'google' do
+ let(:cartridge) { { provider: { id: 'google' } } }
+
+ it 'uses default values when appropriate' do
+ expect(described_class.enabled?(cartridge, :repl)).to be(true)
+ expect(described_class.enabled?(cartridge, :eval)).to be(true)
+ end
end
end
diff --git a/static/cartridges/default.yml b/static/cartridges/default.yml
index 98dd47b..fbf449b 100644
--- a/static/cartridges/default.yml
+++ b/static/cartridges/default.yml
@@ -30,5 +30,7 @@ interfaces:
feedback: true
provider:
+ options:
+ stream: true
settings:
stream: true