close
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions server/handler/experiments.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,57 @@ func CreateExperiment(repo *repository.Experiments) RequestProcessFunc {
return serializer.NewExperimentResponse(experiment, 0), nil
}
}

type updateExperimentReq struct {
Name string `json:"name"`
Description string `json:"description"`
}

// UpdateExperiment returns a function that updates the experiment as passed in the body request
func UpdateExperiment(repo *repository.Experiments, assignmentsRepo *repository.Assignments) RequestProcessFunc {
return func(r *http.Request) (*serializer.Response, error) {
userID, err := service.GetUserID(r.Context())
if err != nil {
return nil, err
}

experimentID, err := urlParamInt(r, "experimentId")
if err != nil {
return nil, err
}

experiment, err := repo.GetByID(experimentID)
if err != nil {
return nil, err
}
if experiment == nil {
return nil, serializer.NewHTTPError(http.StatusNotFound, "no experiment found")
}

var updateExperimentReq updateExperimentReq
body, err := ioutil.ReadAll(r.Body)
if err != nil {
return nil, serializer.NewHTTPError(http.StatusBadRequest, err.Error())
}

err = json.Unmarshal(body, &updateExperimentReq)
if err != nil {
return nil, serializer.NewHTTPError(http.StatusBadRequest, err.Error())
}

experiment.Name = updateExperimentReq.Name
experiment.Description = updateExperimentReq.Description

err = repo.Update(experiment)
if err != nil {
return nil, err
}

progress, err := experimentProgress(assignmentsRepo, experiment.ID, userID)
if err != nil {
return nil, err
}

return serializer.NewExperimentResponse(experiment, progress), nil
}
}
29 changes: 29 additions & 0 deletions server/handler/experiments_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,32 @@ func TestCreateExperiment(t *testing.T) {
Description: "test",
}, 0), res)
}

func TestUpdateExperiment(t *testing.T) {
assert := assert.New(t)

db := testDB()
repo := repository.NewExperiments(db.DB)
assignmentsRepo := repository.NewAssignments(db.DB)
handler := handler.UpdateExperiment(repo, assignmentsRepo)

json := `{"name": "new", "description": "test"}`
req, _ := http.NewRequest("PUT", "/experiments/1", strings.NewReader(json))
req = chiRequest(req, map[string]string{"experimentId": "1"})
req = reqWithUser(req, 1)
res, err := handler(req)
assert.Nil(err)

assert.Equal(serializer.NewExperimentResponse(&model.Experiment{
ID: 1,
Name: "new",
Description: "test",
}, 0), res)

req, _ = http.NewRequest("PUT", "/experiments/2", strings.NewReader(json))
req = chiRequest(req, map[string]string{"experimentId": "2"})
req = reqWithUser(req, 1)
res, err = handler(req)
assert.Nil(res)
assert.Equal(serializer.NewHTTPError(http.StatusNotFound, "no experiment found"), err)
}
6 changes: 6 additions & 0 deletions server/handler/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/go-chi/chi"
"github.com/pressly/lg"
"github.com/src-d/code-annotation/server/dbutil"
"github.com/src-d/code-annotation/server/service"
)

func testDB() *dbutil.DB {
Expand Down Expand Up @@ -41,3 +42,8 @@ func chiRequest(req *http.Request, params map[string]string) *http.Request {

return req.WithContext(ctx)
}

func reqWithUser(req *http.Request, userID int) *http.Request {
ctx := service.SetUserID(req.Context(), userID)
return req.WithContext(ctx)
}
7 changes: 7 additions & 0 deletions server/repository/experiments.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ func (repo *Experiments) getWithQuery(queryRow scannable) (*model.Experiment, er
const selectExperimentsWhereIDSQL = `SELECT * FROM experiments WHERE id=$1`
const selectExperimentsSQL = `SELECT * FROM experiments`
const insertExperimentSQL = `INSERT INTO experiments (name, description) VALUES ($1, $2)`
const updateExperimentSQL = `UPDATE experiments SET name=$1, description=$2 WHERE id=$3`

// GetByID returns the Experiment with the given ID. If the Experiment does not
// exist, it returns nil, nil
Expand Down Expand Up @@ -86,3 +87,9 @@ func (repo *Experiments) Create(m *model.Experiment) error {

return nil
}

// Update experiment model in database
func (repo *Experiments) Update(m *model.Experiment) error {
_, err := repo.db.Exec(updateExperimentSQL, m.Name, m.Description, m.ID)
return err
}
2 changes: 2 additions & 0 deletions server/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ func Router(
r.Route("/experiments/{experimentId}", func(r chi.Router) {

r.Get("/", handler.APIHandlerFunc(handler.GetExperimentDetails(experimentRepo, assignmentRepo)))
r.With(requesterACL.Middleware).
Put("/", handler.APIHandlerFunc(handler.UpdateExperiment(experimentRepo, assignmentRepo)))

r.Route("/assignments", func(r chi.Router) {

Expand Down
7 changes: 6 additions & 1 deletion server/service/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func (j *JWT) Middleware(next http.Handler) http.Handler {
w.WriteHeader(http.StatusUnauthorized)
return
}
r = r.WithContext(context.WithValue(r.Context(), userIDKey, claims.ID))
r = r.WithContext(SetUserID(r.Context(), claims.ID))
next.ServeHTTP(w, r)
})
}
Expand All @@ -87,6 +87,11 @@ func getUserInt(ctx context.Context) (int, bool) {
return i, ok
}

// SetUserID sets the user ID to the context
func SetUserID(ctx context.Context, userID int) context.Context {
return context.WithValue(ctx, userIDKey, userID)
}

// GetUserID gets the user ID set by the JWT middleware in the Context
func GetUserID(ctx context.Context) (int, error) {
id, ok := getUserInt(ctx)
Expand Down