This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
class MultiQueryAttention(nn.Module): | |
def __init__(self, embed_dim, num_heads): | |
super().__init__() | |
self.embed_dim = embed_dim | |
self.num_heads = num_heads | |
self.head_dim = embed_dim // num_heads | |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import time | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
device = "cpu" | |
tokenizer = AutoTokenizer.from_pretrained("gpt2") | |
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device) | |
dur = {} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Tensor: | |
def __init__(self, value, requires_grad=False): | |
self.value = value # Store the value of the tensor | |
self.grad = 0 # Gradient initialized to zero | |
self.requires_grad = requires_grad | |
self._backward = lambda: None # Function to compute gradient | |
self._prev = set() # Track previous nodes for backpropagation | |
def backward(self): | |
"""Computes gradients using reverse-mode automatic differentiation.""" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from transformers import GPT2LMHeadModel, GPT2Tokenizer | |
# Load GPT-2 model and tokenizer | |
tokenizer = GPT2Tokenizer.from_pretrained("gpt2") | |
model = GPT2LMHeadModel.from_pretrained("gpt2") | |
model.eval() | |
def compare_outputs(prompt): | |
print('Prompt:', prompt) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Results | |
## Feature 1: Biden's health and cognitive abilities. | |
### Stats | |
F1 score: 0.5 | |
Pearson Correlation: -0.38014296063485287 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import matplotlib.pyplot as plt | |
# Parameters | |
T = 35 # Number of trials | |
beta = 0.5 # Probability parameter for z | |
alpha_1 = 0.2 # Probability parameter for x if z=0 | |
alpha_2 = 0.6 # Probability parameter for x if z=1 | |
# Simulate z |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_response_sql(user_query, chat_history, plot=False): | |
# Specify the path to the SQLite database | |
db_path = "metadataDB/output_database.db" | |
# Connect to the SQLite database | |
db = SQLDatabase.from_uri(f"sqlite:///{db_path}") | |
underspecified = classify_underspecified_query(user_query , chat_history) | |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_response_sql(user_query, chat_history, plot=False): | |
# Specify the path to the SQLite database | |
db_path = "metadataDB/output_database.db" | |
# Connect to the SQLite database | |
db = SQLDatabase.from_uri(f"sqlite:///{db_path}") | |
underspecified = classify_underspecified_query(user_query , chat_history) | |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
from typing import Iterable | |
from collections import Counter | |
import os | |
import logging | |
import sys | |
import json | |
import click | |
import datasets | |
import numpy as np |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
import argparse | |
from typing import Optional, Union, Tuple | |
import torch | |
torch.manual_seed(0) | |
from transformers import BertModel, BertTokenizer, PreTrainedModel, BertConfig | |
from transformers.modeling_outputs import MultipleChoiceModelOutput |
NewerOlder